1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "tensorflow/compiler/xla/executable_run_options.h" 25 #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" 26 #include "tensorflow/compiler/xla/service/allocation_tracker.h" 27 #include "tensorflow/compiler/xla/service/backend.h" 28 #include "tensorflow/compiler/xla/service/channel_tracker.h" 29 #include "tensorflow/compiler/xla/service/compilation_cache.h" 30 #include "tensorflow/compiler/xla/service/computation_tracker.h" 31 #include "tensorflow/compiler/xla/service/device_memory_allocator.h" 32 #include "tensorflow/compiler/xla/service/executable.h" 33 #include "tensorflow/compiler/xla/service/execution_tracker.h" 34 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" 35 #include "tensorflow/compiler/xla/service/hlo_module.h" 36 #include "tensorflow/compiler/xla/service/hlo_module_config.h" 37 #include "tensorflow/compiler/xla/service/session.pb.h" 38 #include "tensorflow/compiler/xla/service/user_computation.h" 39 #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" 40 #include "tensorflow/compiler/xla/service_interface.h" 41 #include "tensorflow/compiler/xla/statusor.h" 42 #include "tensorflow/compiler/xla/types.h" 43 #include "tensorflow/compiler/xla/xla.pb.h" 44 #include "tensorflow/compiler/xla/xla_data.pb.h" 45 #include "tensorflow/core/lib/gtl/array_slice.h" 46 #include "tensorflow/core/platform/logging.h" 47 #include "tensorflow/core/platform/macros.h" 48 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 49 50 namespace xla { 51 52 // Options to configure the service when it is created. 53 class ServiceOptions { 54 public: 55 // Set the platform backing the service, or nullptr for the default platform. 56 ServiceOptions& set_platform(perftools::gputools::Platform* platform); 57 perftools::gputools::Platform* platform() const; 58 59 // Set the number of replicas to use when compiling replicated 60 // programs. 61 ServiceOptions& set_number_of_replicas(int number_of_replicas); 62 int number_of_replicas() const; 63 64 // Sets the thread pool size for parallel execution of an individual operator. 65 ServiceOptions& set_intra_op_parallelism_threads(int num_threads); 66 int intra_op_parallelism_threads() const; 67 68 private: 69 perftools::gputools::Platform* platform_ = nullptr; 70 int number_of_replicas_ = 1; 71 int intra_op_parallelism_threads_ = -1; 72 }; 73 74 // The XLA service object, which is the same across all platforms. It maintains 75 // the service state of computations and allocations, and delegates 76 // target-specific requests to the target-specific infrastructure 77 // (target-specific compiler, StreamExecutor). 78 class Service : public ServiceInterface { 79 public: 80 // Factory method for creating a new Service. 81 static StatusOr<std::unique_ptr<Service>> NewService( 82 perftools::gputools::Platform* platform = nullptr); 83 static StatusOr<std::unique_ptr<Service>> NewService( 84 const ServiceOptions& options); 85 86 // Creates a new computation with the given name. 87 // A unique ComputationHandle is returned. 88 tensorflow::Status Computation(const ComputationRequest* arg, 89 ComputationResponse* result) override; 90 91 // Unregisters a previously-allocated global handle. 92 // 93 // If the handle given is not currently allocated, a NOT_FOUND status is 94 // returned. 95 tensorflow::Status Unregister(const UnregisterRequest* arg, 96 UnregisterResponse* result) override; 97 98 // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each 99 // element in the tuple. 100 tensorflow::Status DeconstructTuple( 101 const DeconstructTupleRequest* arg, 102 DeconstructTupleResponse* result) override; 103 104 // Modifies the provided computation so that subsequent executions 105 // will compute the provided ComputationDataHandle, rather than the 106 // last expression enqueued on that Computation. 107 tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg, 108 SetReturnValueResponse* results) override; 109 110 // Executes a computation with the provided global data passed as 111 // immutable arguments. Returns global data output and execution timing. 112 tensorflow::Status Execute(const ExecuteRequest* arg, 113 ExecuteResponse* result) override; 114 115 // Executes one or more computations in parallel with the provided global data 116 // passed as immutable arguments. Returns global data output for each 117 // computation. 118 tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, 119 ExecuteParallelResponse* result) override; 120 121 // Requests one or more device handles from the target. 122 // 123 // When N device handles are requested and the number of replicas is R, at 124 // least N * R devices must be available. The devices are assigned based on 125 // the device ordinals such that the first R available devices are assigned to 126 // the first set of replicas, and the next R devices to the second set of 127 // replicas, etc. Each returned device handle represents the device with the 128 // replica id 0. 129 tensorflow::Status GetDeviceHandles( 130 const GetDeviceHandlesRequest* arg, 131 GetDeviceHandlesResponse* result) override; 132 133 // Asynchronously executes a computation with provided arguments. Invokes 134 // the provided computation with the provided global data passed as 135 // immutable arguments. Returns a handle to the execution. 136 // 137 // (Note: The corresponding function in xla::Client was removed as part of 138 // b/64116060, in an attempt to simplify our API. We're keeping this around 139 // for now in case we want to expose this to clients in a different way.) 140 tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, 141 ExecuteAsyncResponse* result) override; 142 143 // Waits until the specified execution is complete and returns the result. 144 // Calling this API multiple times with the same execution handle returns the 145 // method with an error since the execution handle is destroyed after the 146 // first call. 147 tensorflow::Status WaitForExecution( 148 const WaitForExecutionRequest* arg, 149 WaitForExecutionResponse* result) override; 150 151 // Requests that global data be transferred to the client in literal form. 152 tensorflow::Status TransferToClient( 153 const TransferToClientRequest* arg, 154 TransferToClientResponse* result) override; 155 156 // Transfers data from a literal provided by the client, into device memory. 157 tensorflow::Status TransferToServer( 158 const TransferToServerRequest* arg, 159 TransferToServerResponse* result) override; 160 161 // Transfers data from a literal provided by the client, into the Infeed 162 // buffer of the device. 163 tensorflow::Status TransferToInfeed( 164 const TransferToInfeedRequest* arg, 165 TransferToInfeedResponse* result) override; 166 167 // Transfers data from the Outfeed othe device to the literal provided by the 168 // client. 169 tensorflow::Status TransferFromOutfeed( 170 const TransferFromOutfeedRequest* arg, 171 TransferFromOutfeedResponse* result) override; 172 173 // Resets devices, clearing all existing state on all the devices associated 174 // with this service (including memory allocated on the devices). 175 // 176 // ResetDevice may only be called where no previous Execution state on the 177 // device is used by the next Execution. 178 // 179 // ResetDevice should be called before an Execution that expect the device to 180 // be in the reset state. For example, if the prior Execution modifies device 181 // state (e.g., architectural state) that the next Execution depends on. 182 tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, 183 ResetDeviceResponse* result) override; 184 185 // Tests if an expression is a compile-time constant. 186 tensorflow::Status IsConstant(const IsConstantRequest* arg, 187 IsConstantResponse* result) override; 188 189 // Computes the value of a constant expression. 190 tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, 191 ComputeConstantResponse* result) override; 192 193 // Returns the shape (with layout) of an array associated with a given data 194 // handle. 195 tensorflow::Status GetShape(const GetShapeRequest* arg, 196 GetShapeResponse* result) override; 197 198 // Returns the program shape of the computation associated with the given 199 // handle. 200 tensorflow::Status GetComputationShape( 201 const GetComputationShapeRequest* arg, 202 GetComputationShapeResponse* result) override; 203 204 ///// 205 // Computation-oriented methods. 206 207 // Enqueues an Op on the computation. 208 tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override; 209 210 // Retrieves the inferred shape for a value within a computation. 211 tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, 212 GetLocalShapeResponse* result) override; 213 214 // Retrieves the statistics of a computation. 215 tensorflow::Status GetComputationStats( 216 const ComputationStatsRequest* arg, 217 ComputationStatsResponse* result) override; 218 219 // Snapshots the current state of a computation handle into a serializable 220 // protocol buffer form, so it can be loaded via 221 // LoadComputationSnapshot. 222 tensorflow::Status SnapshotComputation( 223 const SnapshotComputationRequest* arg, 224 SnapshotComputationResponse* result) override; 225 226 // Loads a computation from a serialized protocol buffer created via 227 // SnapshotComputation. 228 tensorflow::Status LoadComputationSnapshot( 229 const LoadComputationSnapshotRequest* arg, 230 LoadComputationSnapshotResponse* result) override; 231 232 // Creates a unique channel handle that can be used for Send/Recv 233 // instructions. 234 tensorflow::Status CreateChannelHandle( 235 const CreateChannelHandleRequest* arg, 236 CreateChannelHandleResponse* result) override; 237 238 // Returns the ComputationTracker of the current service instance. 239 // Only used in unit tests to access user computations from client. 240 const ComputationTracker& computation_tracker() { 241 return computation_tracker_; 242 } 243 244 // Returns the backend used to execute computations. 245 const Backend& backend() const { return *execute_backend_; } 246 Backend* mutable_backend() { return execute_backend_.get(); } 247 248 private: 249 // A private overload for Service itself, used by other methods within this 250 // class. 251 StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig( 252 const ProgramShape& program_shape, 253 tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, 254 const ExecutionOptions& execution_options, 255 const UserComputation& user_computation); 256 257 protected: 258 friend class LocalExecutable; 259 260 // The constructor is private. Use the NewService factory to create new 261 // service objects. 262 Service(const ServiceOptions& options, 263 std::unique_ptr<Backend> execute_backend); 264 265 static StatusOr<std::unique_ptr<Backend>> CreateComputeConstantBackend(); 266 267 // Resolves the given argument handles in the allocation tracker and returns 268 // the corresponding allocations. The function also verifies that each 269 // allocation matches the execution platform and device ordinal. 270 StatusOr<std::vector<const ShapedBuffer*>> ResolveAndValidateArguments( 271 tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments, 272 int device_ordinal); 273 274 // Create a Hlo module config for the given program shape and arguments. 275 // execution_options is optional; if not given a default is used. 276 StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig( 277 const ProgramShape& program_shape, 278 tensorflow::gtl::ArraySlice<const Shape*> argument_shapes, 279 const ExecutionOptions* execution_options, 280 const UserComputation& user_computation); 281 282 // Builds an Executable for the given parameters. 283 // 284 // If device_allocator is not null, the compiler may use it to allocate temp 285 // buffers, which the compiler is responsible for freeing. The allocator 286 // given here need not match the allocator used when running the executable. 287 StatusOr<std::unique_ptr<Executable>> BuildExecutable( 288 const VersionedComputationHandle& versioned_handle, 289 std::unique_ptr<HloModuleConfig> module_config, Backend* backend, 290 perftools::gputools::StreamExecutor* executor, 291 DeviceMemoryAllocator* device_allocator = nullptr); 292 293 // Same as BuildExecutable() above, but builds a list of Executables for the 294 // given computations that may interact with each other. 295 StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables( 296 std::vector<VersionedComputationHandle> versioned_handles, 297 std::vector<std::unique_ptr<HloModuleConfig>> module_configs, 298 Backend* backend, 299 std::vector<std::vector<perftools::gputools::StreamExecutor*>> executors, 300 DeviceMemoryAllocator* device_allocator); 301 302 // Similar to BuildExecutable, but look in the compilation cache for the 303 // executable first. If the executable is not in the cache, it is built and 304 // inserted into the cache. 305 StatusOr<std::shared_ptr<Executable>> BuildAndCacheExecutable( 306 const VersionedComputationHandle& versioned_handle, 307 std::unique_ptr<HloModuleConfig> module_config, Backend* backend, 308 perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile, 309 DeviceMemoryAllocator* device_allocator = nullptr); 310 311 // Runs the given executable with the given arguments and register the result 312 // in the allocation tracker. The handle of the result from the tracker is 313 // returned. If the parameter "profile" is not null, it points to an 314 // ExecutionProfile object which will be filled in with profile data. 315 StatusOr<GlobalDataHandle> ExecuteAndRegisterResult( 316 Executable* executable, 317 const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, 318 Backend* backend, perftools::gputools::StreamExecutor* executor, 319 const string& result_tag, ExecutionProfile* profile); 320 321 // Runs the given executables with the given arguments and register the result 322 // from each executable in the allocation tracker. The handles of the result 323 // from the tracker are returned. 324 StatusOr<std::vector<GlobalDataHandle>> ExecuteParallelAndRegisterResult( 325 tensorflow::gtl::ArraySlice<Executable*> executables, 326 tensorflow::gtl::ArraySlice<std::vector<const ShapedBuffer*>> arguments, 327 Backend* backend, 328 tensorflow::gtl::ArraySlice<DeviceHandle> device_handles, 329 tensorflow::gtl::ArraySlice<string> result_tags, 330 ExecutionProfile* profile); 331 332 // Convenience function for adding a function to a user computation. 333 template <typename RequestT, typename ResponseT> 334 tensorflow::Status AddInstruction( 335 const RequestT* arg, ResponseT* result, 336 const std::function<StatusOr<ComputationDataHandle>(UserComputation*)>& 337 adder); 338 339 // Convenience function which checks whether the given shape_with_layout 340 // (presumably passed by the client to set the result layout) is valid for the 341 // given computation result shape. 342 tensorflow::Status ValidateResultShapeWithLayout( 343 const Shape& shape_with_layout, const Shape& result_shape) const; 344 345 // Returns the stream executors assigned to the replicas represented by the 346 // given device handle. Each device_handle is a virtual replicated device that 347 // represents a set of physical devices for the replicas. 348 StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas( 349 const Backend& backend, const DeviceHandle& device_handle) const; 350 351 Status MaybeDumpHloModule(const HloModule& module) const; 352 353 // Returns the device handle that represents the replicated device for a 354 // single computation that is not model-parallelized. 355 DeviceHandle SingleComputationDeviceHandle() const; 356 357 ServiceOptions options_; 358 359 // Tracks computations built via the API. 360 ComputationTracker computation_tracker_; 361 362 // Tracks channels created via the API. 363 ChannelTracker channel_tracker_; 364 365 // Tracks allocations made via the API and computation execution. 366 AllocationTracker allocation_tracker_; 367 368 // Tracks asynchronously launched executions via the API. 369 ExecutionTracker execution_tracker_; 370 371 // Cache containing previously built Executables. 372 CompilationCache compilation_cache_; 373 374 // Backend to compile and execute computations on. 375 // 376 // TODO(b/28616830): Support multiple backends for execution. 377 std::unique_ptr<Backend> execute_backend_; 378 379 TF_DISALLOW_COPY_AND_ASSIGN(Service); 380 }; 381 382 } // namespace xla 383 384 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_ 385