/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/client.h"

#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

Client::Client(ServiceInterface* stub) : stub_(stub) {}

Client::~Client() = default;

StatusOr<Literal> Client::Transfer(const GlobalData& data,
                                   const Shape* shape_with_layout) {
  TransferToClientRequest request;
  *request.mutable_data() = data.handle();
  if (shape_with_layout != nullptr) {
    *request.mutable_shape_with_layout() = shape_with_layout->ToProto();
  }
  TransferToClientResponse response;

  VLOG(1) << "making transfer request";
  VLOG(3) << "TransferToClientRequest: {" << request.DebugString() << "}";
  Status s = stub_->TransferToClient(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "TransferToClientResponse: {" << response.DebugString() << "}";

  if (!response.has_literal()) {
    return FailedPrecondition(
        "server provided response without a literal in "
        "TransferToClient request");
  }
  return Literal::CreateFromProto(*response.mutable_literal());
}

StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
    const LiteralSlice& literal, const DeviceHandle* device_handle) {
  TransferToServerRequest request;
  *request.mutable_literal() = literal.ToProto();
  if (device_handle) {
    *request.mutable_device_handle() = *device_handle;
  }
  TransferToServerResponse response;

  VLOG(1) << "making transfer to server request";
  VLOG(3) << "TransferToServerRequest: {" << request.DebugString() << "}";
  Status s = stub_->TransferToServer(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "TransferToServerResponse: {" << response.DebugString() << "}";

  if (!response.has_data()) {
    return FailedPrecondition(
        "server provided response without a data handle in "
        "TransferToServer request");
  }

  return absl::make_unique<GlobalData>(stub_, response.data());
}
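// Illustrative host round trip (a sketch, not part of this file's API):
// upload a literal with TransferToServer() and read it back with Transfer().
// Assumes a `client` pointer to this Client and LiteralUtil from
// tensorflow/compiler/xla/literal_util.h.
//
//   Literal input = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> remote,
//                       client->TransferToServer(input));
//   TF_ASSIGN_OR_RETURN(Literal round_tripped, client->Transfer(*remote));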
Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id,
                                const DeviceHandle* device_handle) {
  TransferToInfeedRequest request;
  *request.mutable_literal() = literal.ToProto();
  if (device_handle) {
    *request.mutable_device_handle() = *device_handle;
  }
  request.set_replica_id(replica_id);
  TransferToInfeedResponse response;

  VLOG(1) << "making transfer to infeed request";
  VLOG(3) << "TransferToInfeedRequest: {" << request.DebugString() << "}";
  Status s = stub_->TransferToInfeed(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "TransferToInfeedResponse: {" << response.DebugString() << "}";
  return Status::OK();
}

StatusOr<Literal> Client::TransferFromOutfeed(
    const Shape* shape_with_layout, int64 replica_id,
    const DeviceHandle* device_handle) {
  TransferFromOutfeedRequest request;
  if (device_handle) {
    *request.mutable_device_handle() = *device_handle;
  }
  request.set_replica_id(replica_id);
  if (shape_with_layout != nullptr) {
    *request.mutable_shape_with_layout() = shape_with_layout->ToProto();
  }
  TransferFromOutfeedResponse response;

  VLOG(1) << "making transfer from outfeed request";
  VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}";
  Status s = stub_->TransferFromOutfeed(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}";

  if (!response.has_literal()) {
    return FailedPrecondition(
        "server provided response without a literal in "
        "TransferFromOutfeed request");
  }

  return Literal::CreateFromProto(response.literal());
}

Status Client::ResetDevice() {
  ResetDeviceRequest request;
  ResetDeviceResponse response;

  VLOG(1) << "making reset device request";
  VLOG(3) << "ResetDeviceRequest: {" << request.DebugString() << "}";
  Status s = stub_->ResetDevice(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "ResetDeviceResponse: {" << response.DebugString() << "}";
  return Status::OK();
}

StatusOr<Literal> Client::ExecuteAndTransfer(
    const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
    const ExecutionOptions* execution_options,
    ExecutionProfile* execution_profile) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<GlobalData> data,
      Execute(computation, arguments, execution_options, execution_profile));

  absl::optional<Shape> shape_with_output_layout;
  if (execution_options && execution_options->has_shape_with_output_layout()) {
    shape_with_output_layout =
        Shape(execution_options->shape_with_output_layout());
  }
  return Transfer(*data, shape_with_output_layout.has_value()
                             ? &(*shape_with_output_layout)
                             : nullptr);
}
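// Sketch of the composition performed by ExecuteAndTransfer() above,
// assuming the client.h declarations default execution_options and
// execution_profile to nullptr:
//
//   TF_ASSIGN_OR_RETURN(Literal result,
//                       client->ExecuteAndTransfer(computation, arguments));
//   // ... behaves like:
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
//                       client->Execute(computation, arguments));
//   TF_ASSIGN_OR_RETURN(Literal result, client->Transfer(*data));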
StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
                                          const Layout* output_layout) const {
  ComputeConstantGraphRequest request;
  *request.mutable_computation() = computation.proto();
  if (output_layout != nullptr) {
    *request.mutable_output_layout() = output_layout->ToProto();
  }

  ComputeConstantResponse response;

  VLOG(2) << "making compute-constant-graph request";
  Status s = stub_->ComputeConstantGraph(&request, &response);
  VLOG(2) << "done with request";

  if (!s.ok()) {
    return s;
  }

  VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";

  if (!response.has_literal()) {
    return InternalError(
        "no computed literal in the provided response in ComputeConstantGraph "
        "request");
  }
  return Literal::CreateFromProto(response.literal());
}

StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
  TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module());
  return XlaComputation(module.hlo().hlo_module());
}

StatusOr<ExecutionHandle> Client::Compile(
    const XlaComputation& computation, absl::Span<const Shape> argument_shapes,
    const ExecutionOptions* execution_options) {
  CompileRequest request;
  *request.mutable_computation() = computation.proto();

  if (execution_options == nullptr) {
    *request.mutable_execution_options() = CreateDefaultExecutionOptions();
  } else {
    *request.mutable_execution_options() = *execution_options;
  }
  if (request.execution_options().device_handles_size() > 1) {
    return InvalidArgument(
        "Compiling with multiple device handles is not supported. Use "
        "'Execute' instead.");
  }

  // The argument shapes affect how the computation is compiled.
  for (const auto& arg_shape : argument_shapes) {
    *request.add_input_shape_with_layout() = arg_shape.ToProto();
  }

  CompileResponse response;
  VLOG(1) << "making compile request: " << request.ShortDebugString();
  Status s = stub_->Compile(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  TF_RET_CHECK(response.has_handle());
  return response.handle();
}

StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
    const ExecutionHandle& handle, absl::Span<GlobalData* const> arguments,
    ExecutionProfile* execution_profile) {
  ExecuteRequest request;
  *request.mutable_handle() = handle;
  for (GlobalData* argument : arguments) {
    CHECK(argument != nullptr) << "Argument pointers must not be null.";
    *request.add_arguments() = argument->handle();
  }

  ExecuteResponse response;
  VLOG(1) << "making execute request: " << request.ShortDebugString();
  Status s = stub_->Execute(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  if (execution_profile != nullptr) {
    *execution_profile = response.profile();
  }

  return absl::make_unique<GlobalData>(stub_, response.output());
}
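// Sketch of the compile-then-execute flow (hypothetical caller code): the
// Compile()/Execute(handle) pair above amortizes compilation across repeated
// runs, unlike the Execute(computation) overload below, which compiles on
// every call.
//
//   TF_ASSIGN_OR_RETURN(ExecutionHandle exec,
//                       client->Compile(computation, argument_shapes));
//   for (int step = 0; step < num_steps; ++step) {
//     TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> out,
//                         client->Execute(exec, arguments));
//   }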
StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
    const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
    const ExecutionOptions* execution_options,
    ExecutionProfile* execution_profile) {
  // Create an ExecutionOptions if necessary, or set its DeviceHandles.
  absl::optional<ExecutionOptions> options_storage;
  if (!execution_options || execution_options->device_handles().empty()) {
    if (execution_options) {
      options_storage.emplace(*execution_options);
    } else {
      options_storage.emplace(CreateDefaultExecutionOptions());
    }
    execution_options = &*options_storage;

    TF_ASSIGN_OR_RETURN(auto device_handles,
                        GetDeviceHandles(/*device_count=*/1));
    TF_RET_CHECK(!device_handles.empty());
    *options_storage->add_device_handles() = std::move(device_handles[0]);
  }

  std::vector<XlaComputationInstance> computation_instances = {
      XlaComputationInstance{
          computation,
          std::vector<GlobalData*>(arguments.begin(), arguments.end()),
          *execution_options, execution_profile}};

  // Instead of invoking Compile() and Execute(), invoke
  // Service::ExecuteParallel() to execute our one computation. Compile()
  // caches the executable forever, which isn't what we want.
  VLOG(1) << "Making ExecuteParallel request: "
          << execution_options->DebugString();
  TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances));
  VLOG(1) << "ExecuteParallel request done.";

  // The result selection is a bit hacky, but better than assuming it is
  // device 0.
  //
  // TODO(b/118493728): Allow Execute to return one result per computation.
  for (int64 i = 0; i < results.size(); i++) {
    TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i]));
    if (!ShapeUtil::IsEmptyTuple(shape)) {
      VLOG(3) << "Fetching result from device " << i << ": "
              << ShapeUtil::HumanString(shape);
      return std::move(results[i]);
    }
  }
  TF_RET_CHECK(!results.empty());
  VLOG(1) << "Defaulting to device 0 result";
  return std::move(results[0]);
}

StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
    absl::Span<const XlaComputationInstance> computations) {
  ExecuteGraphParallelRequest request;

  for (const XlaComputationInstance& computation : computations) {
    ExecuteGraphRequest single_request;
    *single_request.mutable_computation() = computation.computation.proto();
    for (GlobalData* argument : computation.arguments) {
      *single_request.add_arguments() = argument->handle();
    }
    *single_request.mutable_execution_options() =
        computation.execution_options;
    *request.add_requests() = single_request;
  }

  ExecuteParallelResponse response;
  VLOG(1) << "making execute-graph-parallel request: "
          << request.ShortDebugString();
  Status s = stub_->ExecuteGraphParallel(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  std::vector<std::unique_ptr<GlobalData>> outputs;
  for (size_t i = 0; i < response.responses_size(); ++i) {
    outputs.push_back(
        absl::make_unique<GlobalData>(stub_, response.responses(i).output()));
    if (i < computations.size() &&
        computations[i].execution_profile != nullptr) {
      *computations[i].execution_profile = response.responses(i).profile();
    }
  }

  return std::move(outputs);
}
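// Sketch of a multi-device ExecuteParallel() call (hypothetical caller code,
// with `arguments_for` standing in for a per-device argument lookup): one
// XlaComputationInstance per device handle obtained from GetDeviceHandles()
// below.
//
//   TF_ASSIGN_OR_RETURN(auto devices, client->GetDeviceHandles(2));
//   std::vector<XlaComputationInstance> instances;
//   for (const DeviceHandle& device : devices) {
//     ExecutionOptions options = CreateDefaultExecutionOptions();
//     *options.add_device_handles() = device;
//     instances.push_back(XlaComputationInstance{
//         computation, arguments_for(device), options,
//         /*execution_profile=*/nullptr});
//   }
//   TF_ASSIGN_OR_RETURN(auto outputs, client->ExecuteParallel(instances));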
StatusOr<std::vector<DeviceHandle>> Client::GetDeviceHandles(
    int64 device_count) {
  if (device_count < 1) {
    return InvalidArgument("device_count must be greater than 0");
  }
  GetDeviceHandlesRequest request;
  request.set_device_count(device_count);

  GetDeviceHandlesResponse response;
  VLOG(1) << "making get device request: " << request.ShortDebugString();
  Status s = stub_->GetDeviceHandles(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  std::vector<DeviceHandle> device_handles;
  for (const DeviceHandle& device_handle : response.device_handles()) {
    device_handles.push_back(device_handle);
  }

  return device_handles;
}

Status Client::Unregister(const GlobalData& data) {
  UnregisterRequest request;
  *request.add_data() = data.handle();
  UnregisterResponse response;

  VLOG(1) << "making unregister request";
  Status s = stub_->Unregister(&request, &response);
  VLOG(1) << "done with request";

  return s;
}

StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::DeconstructTuple(
    const GlobalData& data) {
  DeconstructTupleRequest request;
  *request.mutable_tuple_handle() = data.handle();
  DeconstructTupleResponse response;

  VLOG(1) << "making DeconstructTuple request";
  Status s = stub_->DeconstructTuple(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  std::vector<std::unique_ptr<GlobalData>> handles;
  for (auto& handle : response.element_handles()) {
    handles.push_back(absl::make_unique<GlobalData>(stub_, handle));
  }
  return std::move(handles);
}

StatusOr<ComputationStats> Client::GetComputationStats(
    const XlaComputation& computation,
    const DebugOptions& debug_options) const {
  ComputationGraphStatsRequest request;

  // TODO(b/74197823): Find a way to avoid the copy of the hlo proto.
  *request.mutable_computation() = computation.proto();
  *request.mutable_debug_options() = debug_options;
  ComputationStatsResponse response;

  VLOG(1) << "making computation graph stats request";
  Status s = stub_->GetComputationGraphStats(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  CHECK(response.has_stats());
  return response.stats();
}

StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape());
  return absl::make_unique<ProgramShape>(result);
}

StatusOr<Shape> Client::GetShape(const GlobalData& data) {
  GetShapeRequest request;
  *request.mutable_data() = data.handle();
  GetShapeResponse response;

  VLOG(1) << "making get shape request";
  Status s = stub_->GetShape(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  return Shape(response.shape());
}
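// Unit note for ExecutionStatsAsString() below: dividing a flop count by a
// duration in nanoseconds yields Gflop/s directly, since
// 1 flop/ns = 1e9 flop/s = 1 Gflop/s.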
StatusOr<string> Client::ExecutionStatsAsString(
    const XlaComputation& computation, const ExecutionProfile& profile) {
  TF_ASSIGN_OR_RETURN(
      auto computation_stats,
      GetComputationStats(computation, GetDebugOptionsFromFlags()));
  int64 total_flops = computation_stats.flop_count() +
                      computation_stats.transcendental_count();
  if (profile.compute_time_ns() > 0) {
    int64 nanoseconds = profile.compute_time_ns();
    int64 cycle_count = profile.compute_cycle_count();
    // Use floating-point division to avoid truncating to whole Gflop/s.
    double gflops = static_cast<double>(total_flops) / nanoseconds;
    return absl::StrCat(
        "[Execution Statistics] flop count: ", computation_stats.flop_count(),
        ", transcendental count: ", computation_stats.transcendental_count(),
        ", compute execution time: ", nanoseconds, " nsec",
        ", compute cycles: ", cycle_count, ", performance: ", gflops,
        " gflop/s");
  }
  return string("[Execution Statistics] not available.");
}

StatusOr<ChannelHandle> Client::CreateChannelHandleByType(
    ChannelHandle::ChannelType type) {
  CreateChannelHandleRequest request;
  request.set_channel_type(type);
  CreateChannelHandleResponse response;

  VLOG(1) << "making create channel handle request";
  Status s = stub_->CreateChannelHandle(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  return response.channel();
}

StatusOr<ChannelHandle> Client::CreateChannelHandle() {
  return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_DEVICE);
}

StatusOr<ChannelHandle> Client::CreateHostToDeviceChannelHandle() {
  return CreateChannelHandleByType(ChannelHandle::HOST_TO_DEVICE);
}

StatusOr<ChannelHandle> Client::CreateDeviceToHostChannelHandle() {
  return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_HOST);
}

}  // namespace xla