/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/user_computation.h"

#include <algorithm>
#include <set>
#include <stack>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {
namespace {

HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
  switch (unop) {
    case UNOP_ABS:
      return HloOpcode::kAbs;
    case UNOP_CEIL:
      return HloOpcode::kCeil;
    case UNOP_COS:
      return HloOpcode::kCos;
    case UNOP_EXP:
      return HloOpcode::kExp;
    case UNOP_FLOOR:
      return HloOpcode::kFloor;
    case UNOP_IMAG:
      return HloOpcode::kImag;
    case UNOP_IS_FINITE:
      return HloOpcode::kIsFinite;
    case UNOP_LOG:
      return HloOpcode::kLog;
    case UNOP_NOT:
      return HloOpcode::kNot;
    case UNOP_NEGATE:
      return HloOpcode::kNegate;
    case UNOP_REAL:
      return HloOpcode::kReal;
    case UNOP_ROUND_NEAREST_AFZ:
      return HloOpcode::kRoundNearestAfz;
    case UNOP_SIGN:
      return HloOpcode::kSign;
    case UNOP_SIN:
      return HloOpcode::kSin;
    case UNOP_SORT:
      return HloOpcode::kSort;
    case UNOP_TANH:
      return HloOpcode::kTanh;
    default:
      LOG(FATAL) << "unhandled operation " << unop;
  }
}

HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) {
  switch (binop) {
    case BINOP_ATAN2:
      return HloOpcode::kAtan2;
    case BINOP_COMPLEX:
      return HloOpcode::kComplex;
    case BINOP_MUL:
      return HloOpcode::kMultiply;
    case BINOP_ADD:
      return HloOpcode::kAdd;
    case BINOP_SUB:
      return HloOpcode::kSubtract;
    case BINOP_DIV:
      return HloOpcode::kDivide;
    case BINOP_EQ:
      return HloOpcode::kEq;
    case BINOP_GE:
      return HloOpcode::kGe;
    case BINOP_GT:
      return HloOpcode::kGt;
    case BINOP_LE:
      return HloOpcode::kLe;
    case BINOP_LT:
      return HloOpcode::kLt;
    case BINOP_NE:
      return HloOpcode::kNe;
    case BINOP_MAX:
      return HloOpcode::kMaximum;
    case BINOP_MIN:
      return HloOpcode::kMinimum;
    case BINOP_POW:
      return HloOpcode::kPower;
    case BINOP_REM:
      return HloOpcode::kRemainder;
    case BINOP_OR:
      return HloOpcode::kOr;
    case BINOP_AND:
      return HloOpcode::kAnd;
    case BINOP_SHIFT_LEFT:
      return HloOpcode::kShiftLeft;
    case BINOP_SHIFT_RIGHT_ARITHMETIC:
      return HloOpcode::kShiftRightArithmetic;
    case BINOP_SHIFT_RIGHT_LOGICAL:
      return HloOpcode::kShiftRightLogical;
    default:
      LOG(FATAL) << "unhandled operation " << binop;
  }
}

HloOpcode TernaryOperationToHloOpcode(TernaryOperation triop) {
  switch (triop) {
    case TRIOP_CLAMP:
      return HloOpcode::kClamp;
    case TRIOP_SELECT:
      return HloOpcode::kSelect;
    default:
      LOG(FATAL) << "unhandled operation " << triop;
  }
}

HloOpcode VariadicOperationToHloOpcode(VariadicOperation varop) {
  switch (varop) {
    case VAROP_TUPLE:
      return HloOpcode::kTuple;
    default:
      LOG(FATAL) << "unhandled operation " << varop;
  }
}

}  // namespace

/* static */ StatusOr<std::unique_ptr<UserComputation>>
UserComputation::MakeWithRemapping(
    const SessionComputation& session_computation,
    const ComputationHandle& handle,
    const std::map<int64, ComputationHandle>& old_to_new) {
  auto user_computation =
      MakeUnique<UserComputation>(session_computation.name(), handle);
  {
    tensorflow::mutex_lock lock(user_computation->mutex_);
    user_computation->session_computation_ = session_computation;
    // Resume handle numbering one past the largest handle value already
    // present in the remapped session computation.
    user_computation->next_handle_value_ =
        std::max_element(session_computation.requests().begin(),
                         session_computation.requests().end(),
                         [](const std::pair<int64, OperationRequest>& lhs,
                            const std::pair<int64, OperationRequest>& rhs) {
                           return lhs.first < rhs.first;
                         })
            ->first +
        1;
    TF_RETURN_IF_ERROR(user_computation->RemapEmbeddedComputations(old_to_new));
  }

  return std::move(user_computation);
}

UserComputation::UserComputation(const string& name,
                                 const ComputationHandle& handle)
    : name_(name), next_handle_value_(1) {
  *session_computation_.mutable_computation_handle() = handle;
  session_computation_.set_name(name);

  VLOG(1) << "New UserComputation \"" << name
          << "\", handle: " << handle.handle();
}

ComputationDataHandle UserComputation::CreateComputationDataHandle() {
  ComputationDataHandle handle;
  handle.set_handle(next_handle_value_);
  // Handles are used as Version values and *must* be assigned consecutively
  // for computation versioning to work.
  next_handle_value_++;
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddParameterInstruction(
    const ParameterRequest& parameter_request) {
  tensorflow::mutex_lock lock(mutex_);

  int64 parameter_number = parameter_request.parameter();
  if (parameters_.count(parameter_number) != 0) {
    return InvalidArgument("parameter %lld already registered",
                           parameter_number);
  }
  ComputationDataHandle handle = CreateComputationDataHandle();

  const Shape& validated_shape = parameter_request.shape();
  TF_RETURN_IF_ERROR(
      ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape));

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = validated_shape;
  *request.mutable_request()->mutable_parameter_request() = parameter_request;

  parameters_[parameter_number] = &request;

  VLOG(1) << "AddParameterInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << parameter_request.ShortDebugString();
  return handle;
}

Status UserComputation::AddSendInstruction(const SendRequest& send_request) {
  tensorflow::mutex_lock lock(mutex_);

  // Check if the operand of the instruction is valid.
  TF_RETURN_IF_ERROR(LookUpRequest(send_request.operand()).status());

  // No handle is returned, but a handle must be assigned to this instruction
  // for computation versioning.
  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = ShapeUtil::MakeNil();
  *request.mutable_request()->mutable_send_request() = send_request;

  VLOG(1) << "AddSendInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << send_request.ShortDebugString();
  return Status::OK();
}

StatusOr<ComputationDataHandle> UserComputation::AddRecvInstruction(
    const RecvRequest& recv_request) {
  tensorflow::mutex_lock lock(mutex_);

  const Shape& shape = recv_request.shape();
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_recv_request() = recv_request;

  VLOG(1) << "AddRecvInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << recv_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddPadInstruction(
    const PadRequest& pad_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(pad_request.operand()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* padding_value,
                      LookUpRequest(pad_request.padding_value()));

  TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                      ShapeInference::InferPadShape(
                          operand->output_shape(),
                          padding_value->output_shape(),
                          pad_request.padding_config()));

  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  *request.mutable_request()->mutable_pad_request() = pad_request;

  VLOG(1) << "AddPadInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << pad_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddConstantInstruction(
    const ConstantRequest& constant_request) {
  const Shape& validated_shape = constant_request.literal().shape();
  TF_RETURN_IF_ERROR(
      ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape));

  tensorflow::mutex_lock lock(mutex_);

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = validated_shape;
  *request.mutable_request()->mutable_constant_request() = constant_request;

  VLOG(1) << "AddConstantInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddGatherInstruction(
    const GatherRequest& gather_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* input_request,
                      LookUpRequest(gather_request.input()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request,
                      LookUpRequest(gather_request.gather_indices()));

  TF_ASSIGN_OR_RETURN(
      Shape shape,
      ShapeInference::InferGatherShape(
          input_request->output_shape(), gather_indices_request->output_shape(),
          gather_request.dimension_numbers(),
          AsInt64Slice(gather_request.window_bounds())));

  const ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_gather_request() = gather_request;

  VLOG(1) << "AddGatherInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << gather_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddGetTupleElementInstruction(
    const GetTupleElementRequest& get_tuple_element_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(get_tuple_element_request.operand()));
  if (!ShapeUtil::IsTuple(operand->output_shape())) {
    return InvalidArgument(
        "Operand to GetTupleElement() is not a tuple; got %s",
        ShapeUtil::HumanString(operand->output_shape()).c_str());
  }
  Shape element_shape = ShapeUtil::GetTupleElementShape(
      operand->output_shape(), get_tuple_element_request.index());

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = element_shape;
  *request.mutable_request()->mutable_get_tuple_element_request() =
      get_tuple_element_request;

  VLOG(1) << "AddGetTupleElementInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << get_tuple_element_request.ShortDebugString();
  return handle;
}

Status UserComputation::AddTraceInstruction(const TraceRequest& trace_request) {
  tensorflow::mutex_lock lock(mutex_);

  // Verify that the operand index is valid.
  TF_RETURN_IF_ERROR(LookUpRequest(trace_request.operand()).status());

  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = ShapeUtil::MakeNil();
  *request.mutable_request()->mutable_trace_request() = trace_request;

  VLOG(1) << "AddTraceInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << trace_request.ShortDebugString();
  return Status::OK();
}

StatusOr<ComputationDataHandle> UserComputation::AddRngInstruction(
    const RngRequest& rng_request) {
  tensorflow::mutex_lock lock(mutex_);

  // Check the number of parameters per RNG distribution.
  switch (rng_request.distribution()) {
    case RandomDistribution::RNG_NORMAL:
    case RandomDistribution::RNG_UNIFORM:
      if (rng_request.parameter_size() != 2) {
        return InvalidArgument(
            "RNG distribution (%s) expects 2 parameters, but got %d",
            RandomDistribution_Name(rng_request.distribution()).c_str(),
            rng_request.parameter_size());
      }
      break;
    default:
      LOG(FATAL) << "unhandled distribution " << rng_request.distribution();
  }

  // Verify that the parameter handles are valid.
  for (const ComputationDataHandle& param : rng_request.parameter()) {
    TF_RETURN_IF_ERROR(LookUpRequest(param).status());
  }
  const Shape& validated_shape = rng_request.shape();
  TF_RETURN_IF_ERROR(
      ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = validated_shape;
  *request.mutable_request()->mutable_rng_request() = rng_request;

  VLOG(1) << "AddRngInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << rng_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddMapInstruction(
    const MapRequest& map_request,
    const UserComputation& to_apply_computation) {
  tensorflow::mutex_lock lock(mutex_);

  std::vector<const Shape*> operand_shapes;
  for (const ComputationDataHandle& handle : map_request.operands()) {
    TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
    operand_shapes.push_back(&operand->output_shape());
  }

  VersionedComputationHandle::Version to_apply_version =
      to_apply_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> to_apply_program_shape,
      to_apply_computation.ComputeProgramShape(to_apply_version));
  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape,
                                    AsInt64Slice(map_request.dimensions())));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  request.add_embedded_computation_versions(to_apply_version);
  *request.mutable_request()->mutable_map_request() = map_request;

  VLOG(1) << "AddMapInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << map_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddReduceInstruction(
    const ReduceRequest& reduce_request,
    const UserComputation& to_apply_computation) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(reduce_request.operand()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* init_value,
                      LookUpRequest(reduce_request.init_value()));

  VersionedComputationHandle::Version to_apply_version =
      to_apply_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> to_apply_program_shape,
      to_apply_computation.ComputeProgramShape(to_apply_version));

  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferReduceShape(
          operand->output_shape(), init_value->output_shape(),
          AsInt64Slice(reduce_request.dimensions()), *to_apply_program_shape));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  request.add_embedded_computation_versions(to_apply_version);
  *request.mutable_request()->mutable_reduce_request() = reduce_request;

  VLOG(1) << "AddReduceInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << reduce_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle>
UserComputation::AddBatchNormTrainingInstruction(
    const BatchNormTrainingRequest& batch_norm_training_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(batch_norm_training_request.operand()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
                      LookUpRequest(batch_norm_training_request.scale()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* offset,
                      LookUpRequest(batch_norm_training_request.offset()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];

  TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                      ShapeInference::InferBatchNormTrainingShape(
                          operand->output_shape(), scale->output_shape(),
                          offset->output_shape(),
                          batch_norm_training_request.feature_index()));

  *request.mutable_output_shape() = inferred_shape;

  *request.mutable_output_handle() = handle;

  *request.mutable_request()->mutable_batch_norm_training_request() =
      batch_norm_training_request;

  VLOG(1) << "AddBatchNormTrainingInstruction ("
          << GetVersionedHandleInternal() << "), data handle "
          << handle.handle() << ": "
          << batch_norm_training_request.ShortDebugString();

  return handle;
}

StatusOr<ComputationDataHandle>
UserComputation::AddBatchNormInferenceInstruction(
    const BatchNormInferenceRequest& batch_norm_inference_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(batch_norm_inference_request.operand()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
                      LookUpRequest(batch_norm_inference_request.scale()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* offset,
                      LookUpRequest(batch_norm_inference_request.offset()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* mean,
                      LookUpRequest(batch_norm_inference_request.mean()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* variance,
                      LookUpRequest(batch_norm_inference_request.variance()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];

  TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                      ShapeInference::InferBatchNormInferenceShape(
                          operand->output_shape(), scale->output_shape(),
                          offset->output_shape(), mean->output_shape(),
                          variance->output_shape(),
                          batch_norm_inference_request.feature_index()));

  *request.mutable_output_shape() = inferred_shape;

  *request.mutable_output_handle() = handle;

  *request.mutable_request()->mutable_batch_norm_inference_request() =
      batch_norm_inference_request;

  VLOG(1) << "AddBatchNormInferenceInstruction ("
          << GetVersionedHandleInternal() << "), data handle "
          << handle.handle() << ": "
          << batch_norm_inference_request.ShortDebugString();

  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddBatchNormGradInstruction(
    const BatchNormGradRequest& batch_norm_grad_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(batch_norm_grad_request.operand()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
                      LookUpRequest(batch_norm_grad_request.scale()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* mean,
                      LookUpRequest(batch_norm_grad_request.mean()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* variance,
                      LookUpRequest(batch_norm_grad_request.variance()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* grad_output,
                      LookUpRequest(batch_norm_grad_request.grad_output()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];

  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferBatchNormGradShape(
          operand->output_shape(), scale->output_shape(), mean->output_shape(),
          variance->output_shape(), grad_output->output_shape(),
          batch_norm_grad_request.feature_index()));

  *request.mutable_output_shape() = inferred_shape;

  *request.mutable_output_handle() = handle;

  *request.mutable_request()->mutable_batch_norm_grad_request() =
      batch_norm_grad_request;

  VLOG(1) << "AddBatchNormGradInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << batch_norm_grad_request.ShortDebugString();

  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddReduceWindowInstruction(
    const ReduceWindowRequest& reduce_window_request,
    const UserComputation& to_apply_computation) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(reduce_window_request.operand()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* init_value,
                      LookUpRequest(reduce_window_request.init_value()));

  VersionedComputationHandle::Version to_apply_version =
      to_apply_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> to_apply_program_shape,
      to_apply_computation.ComputeProgramShape(to_apply_version));

  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferReduceWindowShape(
          operand->output_shape(), init_value->output_shape(),
          reduce_window_request.window(), *to_apply_program_shape));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  request.add_embedded_computation_versions(to_apply_version);
  *request.mutable_request()->mutable_reduce_window_request() =
      reduce_window_request;

  VLOG(1) << "AddReduceWindowInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << reduce_window_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle>
UserComputation::AddSelectAndScatterInstruction(
    const SelectAndScatterRequest& select_and_scatter_request,
    const UserComputation& select_computation,
    const UserComputation& scatter_computation) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(select_and_scatter_request.operand()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* source,
                      LookUpRequest(select_and_scatter_request.source()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* init_value,
                      LookUpRequest(select_and_scatter_request.init_value()));

  VersionedComputationHandle::Version select_version =
      select_computation.version();
  TF_ASSIGN_OR_RETURN(std::shared_ptr<const ProgramShape> select_program_shape,
                      select_computation.ComputeProgramShape(select_version));
  VersionedComputationHandle::Version scatter_version =
      scatter_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> scatter_program_shape,
      scatter_computation.ComputeProgramShape(scatter_version));

  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferSelectAndScatterShape(
          operand->output_shape(), *select_program_shape,
          select_and_scatter_request.window(), source->output_shape(),
          init_value->output_shape(), *scatter_program_shape));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  request.add_embedded_computation_versions(select_version);
  request.add_embedded_computation_versions(scatter_version);
  *request.mutable_request()->mutable_select_and_scatter_request() =
      select_and_scatter_request;

  VLOG(1) << "AddSelectAndScatterInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << select_and_scatter_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddReverseInstruction(
    const ReverseRequest& reverse_request) {
  tensorflow::mutex_lock lock(mutex_);
  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(reverse_request.operand()));
  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferReverseShape(
          operand->output_shape(), AsInt64Slice(reverse_request.dimensions())));

  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  *request.mutable_request()->mutable_reverse_request() = reverse_request;
  VLOG(1) << "AddReverseInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << reverse_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddWhileInstruction(
    const WhileRequest& while_request,
    const UserComputation& condition_computation,
    const UserComputation& body_computation) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* init,
                      LookUpRequest(while_request.init()));

  VersionedComputationHandle::Version condition_version =
      condition_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> condition_program_shape,
      condition_computation.ComputeProgramShape(condition_version));

  VersionedComputationHandle::Version body_version =
      body_computation.version();
  TF_ASSIGN_OR_RETURN(std::shared_ptr<const ProgramShape> body_program_shape,
                      body_computation.ComputeProgramShape(body_version));

  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferWhileShape(*condition_program_shape,
                                      *body_program_shape,
                                      init->output_shape()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  request.add_embedded_computation_versions(condition_version);
  request.add_embedded_computation_versions(body_version);
  *request.mutable_request()->mutable_while_request() = while_request;

  VLOG(1) << "AddWhileInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << while_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddConditionalInstruction(
    const ConditionalRequest& conditional_request,
    const UserComputation& true_computation,
    const UserComputation& false_computation) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* pred,
                      LookUpRequest(conditional_request.predicate()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* true_operand,
                      LookUpRequest(conditional_request.true_operand()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* false_operand,
                      LookUpRequest(conditional_request.false_operand()));

  VersionedComputationHandle::Version true_computation_version =
      true_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> true_computation_shape,
      true_computation.ComputeProgramShape(true_computation_version));

  VersionedComputationHandle::Version false_computation_version =
      false_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> false_computation_shape,
      false_computation.ComputeProgramShape(false_computation_version));

  TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                      ShapeInference::InferConditionalShape(
                          pred->output_shape(), true_operand->output_shape(),
                          false_operand->output_shape(),
                          *true_computation_shape, *false_computation_shape));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  request.add_embedded_computation_versions(true_computation_version);
  request.add_embedded_computation_versions(false_computation_version);
  *request.mutable_request()->mutable_conditional_request() =
      conditional_request;

  VLOG(1) << "AddConditionalInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << conditional_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddBroadcastInstruction(
    const BroadcastRequest& broadcast_request) {
  tensorflow::mutex_lock lock(mutex_);

  // Fetches and validates the operand.
  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(broadcast_request.operand()));
  TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                      ShapeInference::InferBroadcastShape(
                          operand->output_shape(),
                          AsInt64Slice(broadcast_request.broadcast_sizes())));

  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  *request.mutable_request()->mutable_broadcast_request() = broadcast_request;

  VLOG(1) << "AddBroadcastInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << broadcast_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddReshapeInstruction(
    const ReshapeRequest& reshape_request) {
  tensorflow::mutex_lock lock(mutex_);

  // Fetches and validates the operand.
  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(reshape_request.operand()));

  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferReshapeShape(
          operand->output_shape(), AsInt64Slice(reshape_request.dimensions()),
          AsInt64Slice(reshape_request.new_sizes())));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  *request.mutable_request()->mutable_reshape_request() = reshape_request;

  VLOG(1) << "AddReshapeInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << reshape_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddTransposeInstruction(
    const TransposeRequest& transpose_request) {
  tensorflow::mutex_lock lock(mutex_);

  // Fetches and validates the operand.
  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(transpose_request.operand()));

  TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                      ShapeInference::InferTransposeShape(
                          operand->output_shape(),
                          AsInt64Slice(transpose_request.dimensions())));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  *request.mutable_request()->mutable_transpose_request() = transpose_request;

  VLOG(1) << "AddTransposeInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << transpose_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddSliceInstruction(
    const SliceRequest& slice_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(slice_request.operand()));

  TF_ASSIGN_OR_RETURN(
      Shape new_shape,
      ShapeInference::InferSliceShape(
          operand->output_shape(), AsInt64Slice(slice_request.start_indices()),
          AsInt64Slice(slice_request.limit_indices()),
          AsInt64Slice(slice_request.strides())));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = new_shape;
  *request.mutable_request()->mutable_slice_request() = slice_request;

  VLOG(1) << "AddSliceInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << slice_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddDynamicSliceInstruction(
    const DynamicSliceRequest& dynamic_slice_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(dynamic_slice_request.operand()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* start_indices,
                      LookUpRequest(dynamic_slice_request.start_indices()));

  TF_ASSIGN_OR_RETURN(
      Shape new_shape,
      ShapeInference::InferDynamicSliceShape(
          operand->output_shape(), start_indices->output_shape(),
          AsInt64Slice(dynamic_slice_request.slice_sizes())));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = new_shape;
  *request.mutable_request()->mutable_dynamic_slice_request() =
      dynamic_slice_request;

  VLOG(1) << "AddDynamicSliceInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << dynamic_slice_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle>
UserComputation::AddDynamicUpdateSliceInstruction(
    const DynamicUpdateSliceRequest& dynamic_update_slice_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(dynamic_update_slice_request.operand()));

  TF_ASSIGN_OR_RETURN(const OperationRequest* update,
                      LookUpRequest(dynamic_update_slice_request.update()));

  TF_ASSIGN_OR_RETURN(
      const OperationRequest* start_indices,
      LookUpRequest(dynamic_update_slice_request.start_indices()));

  TF_ASSIGN_OR_RETURN(Shape new_shape,
                      ShapeInference::InferDynamicUpdateSliceShape(
                          operand->output_shape(), update->output_shape(),
                          start_indices->output_shape()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = new_shape;
  *request.mutable_request()->mutable_dynamic_update_slice_request() =
      dynamic_update_slice_request;

  VLOG(1) << "AddDynamicUpdateSliceInstruction ("
          << GetVersionedHandleInternal() << "), data handle "
          << handle.handle() << ": "
          << dynamic_update_slice_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddConcatenateInstruction(
    const ConcatenateRequest& concatenate_request) {
  tensorflow::mutex_lock lock(mutex_);

  std::vector<const Shape*> operand_shapes;
  for (const ComputationDataHandle& handle : concatenate_request.operands()) {
    TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
    operand_shapes.push_back(&operand->output_shape());
  }

  TF_ASSIGN_OR_RETURN(Shape new_shape,
                      ShapeInference::InferConcatOpShape(
                          operand_shapes, concatenate_request.dimension()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = new_shape;
  *request.mutable_request()->mutable_concatenate_request() =
      concatenate_request;

  VLOG(1) << "AddConcatenateInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << concatenate_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddConvertInstruction(
    const ConvertRequest& convert_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(convert_request.operand()));

  TF_ASSIGN_OR_RETURN(Shape new_shape,
                      ShapeInference::InferConvertShape(
                          operand->output_shape(),
                          convert_request.new_element_type()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = new_shape;
  *request.mutable_request()->mutable_convert_request() = convert_request;

  VLOG(1) << "AddConvertInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << convert_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddBitcastConvertInstruction(
    const ConvertRequest& convert_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(convert_request.operand()));

  TF_ASSIGN_OR_RETURN(Shape new_shape,
                      ShapeInference::InferConvertShape(
                          operand->output_shape(),
                          convert_request.new_element_type()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = new_shape;
  *request.mutable_request()->mutable_bitcast_convert_request() =
      convert_request;

  VLOG(1) << "AddBitcastConvertInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << convert_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddReducePrecisionInstruction(
    const ReducePrecisionRequest& reduce_precision_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(reduce_precision_request.operand()));

  TF_ASSIGN_OR_RETURN(
      Shape new_shape,
      ShapeInference::InferReducePrecisionShape(
          operand->output_shape(), reduce_precision_request.exponent_bits(),
          reduce_precision_request.mantissa_bits()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = new_shape;
  *request.mutable_request()->mutable_reduce_precision_request() =
      reduce_precision_request;

  VLOG(1) << "AddReducePrecisionInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << reduce_precision_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddConvolveInstruction(
    const ConvolveRequest& convolve_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
                      LookUpRequest(convolve_request.lhs()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
                      LookUpRequest(convolve_request.rhs()));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferConvolveShape(
                          lhs->output_shape(), rhs->output_shape(),
                          convolve_request.window(),
                          convolve_request.dimension_numbers()));

  const ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_convolve_request() = convolve_request;

  VLOG(1) << "AddConvolveInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << convolve_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddFftInstruction(
    const FftRequest& fft_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(fft_request.operand()));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferFftShape(
                          operand->output_shape(), fft_request.fft_type(),
                          AsInt64Slice(fft_request.fft_length())));

  const ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_fft_request() = fft_request;

  VLOG(1) << "AddFftInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << fft_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddCrossReplicaSumInstruction(
    const CrossReplicaSumRequest& cross_replica_sum_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(cross_replica_sum_request.operand()));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferCrossReplicaSumShape(
                          {&operand->output_shape()}));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_cross_replica_sum_request() =
      cross_replica_sum_request;

  VLOG(1) << "AddCrossReplicaSumInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << cross_replica_sum_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddInfeedInstruction(
    const InfeedRequest& infeed_request) {
  tensorflow::mutex_lock lock(mutex_);

  const Shape& shape = infeed_request.shape();
  if (!LayoutUtil::HasLayout(shape)) {
    return InvalidArgument("Given shape to Infeed must have a layout");
  }

  const ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_infeed_request() = infeed_request;

  VLOG(1) << "AddInfeedInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << infeed_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddOutfeedInstruction(
    const OutfeedRequest& outfeed_request) {
  tensorflow::mutex_lock lock(mutex_);

  const Shape& shape = outfeed_request.shape();
  if (!LayoutUtil::HasLayout(shape)) {
    return InvalidArgument("Given shape to Outfeed must have a layout");
  }

  // Verify that the operand is valid.
  TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status());

  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_outfeed_request() = outfeed_request;

  VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << outfeed_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddCallInstruction(
    const CallRequest& call_request,
    const UserComputation& to_apply_computation) {
  tensorflow::mutex_lock lock(mutex_);

  std::vector<const Shape*> operand_shapes;
  for (const ComputationDataHandle& handle : call_request.operands()) {
    TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
    operand_shapes.push_back(&operand->output_shape());
  }

  VersionedComputationHandle::Version to_apply_version =
      to_apply_computation.version();
  TF_ASSIGN_OR_RETURN(
      std::shared_ptr<const ProgramShape> to_apply_program_shape,
      to_apply_computation.ComputeProgramShape(to_apply_version));
  TF_ASSIGN_OR_RETURN(
      Shape inferred_shape,
      ShapeInference::InferCallShape(operand_shapes, *to_apply_program_shape));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = inferred_shape;
  request.add_embedded_computation_versions(to_apply_version);
  *request.mutable_request()->mutable_call_request() = call_request;

  VLOG(1) << "AddCallInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << call_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddCustomCallInstruction(
    const CustomCallRequest& custom_call_request) {
  tensorflow::mutex_lock lock(mutex_);

  for (const ComputationDataHandle& handle : custom_call_request.operands()) {
    TF_RETURN_IF_ERROR(LookUpRequest(handle).status());
  }

  if (tensorflow::StringPiece(custom_call_request.call_target_name())
          .starts_with("$")) {
    return InvalidArgument(
        "Invalid custom_call_target \"%s\": Call targets that start with '$' "
        "are reserved for internal use.",
        custom_call_request.call_target_name().c_str());
  }

  const ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = custom_call_request.shape();
  *request.mutable_request()->mutable_custom_call_request() =
      custom_call_request;

  VLOG(1) << "AddCustomCallInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << custom_call_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddHostComputeInstruction(
    const HostComputeRequest& host_compute_request) {
  tensorflow::mutex_lock lock(mutex_);

  for (const ComputationDataHandle& handle : host_compute_request.operands()) {
    TF_RETURN_IF_ERROR(LookUpRequest(handle).status());
  }

  ComputationDataHandle handle = CreateComputationDataHandle();
  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = host_compute_request.shape();
  *request.mutable_request()->mutable_host_compute_request() =
      host_compute_request;

  VLOG(1) << "AddHostComputeInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << host_compute_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddDotInstruction(
    const DotRequest& dot_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
                      LookUpRequest(dot_request.lhs()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
                      LookUpRequest(dot_request.rhs()));

  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferDotOpShape(
                          lhs->output_shape(), rhs->output_shape(),
                          dot_request.dimension_numbers()));

  const ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_dot_request() = dot_request;

  VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << dot_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddUnaryInstruction(
    const UnaryOpRequest& unary_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
                      LookUpRequest(unary_request.operand()));
  TF_ASSIGN_OR_RETURN(
      Shape shape, ShapeInference::InferUnaryOpShape(unary_request.unop(),
                                                     operand->output_shape()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_unary_op_request() = unary_request;

  VLOG(1) << "AddUnaryInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << unary_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddBinaryInstruction(
    const BinaryOpRequest& binary_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
                      LookUpRequest(binary_request.lhs()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
                      LookUpRequest(binary_request.rhs()));
  TF_ASSIGN_OR_RETURN(
      Shape shape,
      ShapeInference::InferBinaryOpShape(
          binary_request.binop(), lhs->output_shape(), rhs->output_shape(),
          AsInt64Slice(binary_request.broadcast_dimensions())));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_binary_op_request() = binary_request;

  VLOG(1) << "AddBinaryInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << binary_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddTernaryInstruction(
    const TernaryOpRequest& ternary_request) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
                      LookUpRequest(ternary_request.lhs()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
                      LookUpRequest(ternary_request.rhs()));
  TF_ASSIGN_OR_RETURN(const OperationRequest* ehs,
                      LookUpRequest(ternary_request.ehs()));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferTernaryOpShape(
                          ternary_request.triop(), lhs->output_shape(),
                          rhs->output_shape(), ehs->output_shape()));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_ternary_op_request() = ternary_request;

  VLOG(1) << "AddTernaryInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << ternary_request.ShortDebugString();
  return handle;
}

StatusOr<ComputationDataHandle> UserComputation::AddVariadicInstruction(
    const VariadicOpRequest& variadic_request) {
  tensorflow::mutex_lock lock(mutex_);

  std::vector<const Shape*> operand_shapes;
  for (const ComputationDataHandle& handle : variadic_request.operands()) {
    TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
    operand_shapes.push_back(&operand->output_shape());
  }

  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferVariadicOpShape(
                          variadic_request.varop(), operand_shapes));

  ComputationDataHandle handle = CreateComputationDataHandle();

  OperationRequest& request =
      (*session_computation_.mutable_requests())[handle.handle()];
  *request.mutable_output_handle() = handle;
  *request.mutable_output_shape() = shape;
  *request.mutable_request()->mutable_variadic_op_request() = variadic_request;

  VLOG(1) << "AddVariadicInstruction (" << GetVersionedHandleInternal()
          << "), data handle " << handle.handle() << ": "
          << variadic_request.ShortDebugString();
  return handle;
}

StatusOr<Shape> UserComputation::GetShape(const ComputationDataHandle& handle) {
  tensorflow::mutex_lock lock(mutex_);

  TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
  return operand->output_shape();
}

Status UserComputation::SetOpMetadata(const ComputationDataHandle& handle,
                                      const OpMetadata& metadata) {
  tensorflow::mutex_lock lock(mutex_);

  int64 handle_value = handle.handle();
  if (session_computation_.requests().count(handle_value) == 0) {
    return InvalidArgument("Invalid handle in SetOpMetadata (%lld)",
                           handle_value);
  }
  *session_computation_.mutable_requests()
       ->at(handle_value)
       .mutable_request()
       ->mutable_metadata() = metadata;
  return Status::OK();
}

Status UserComputation::SetOpSharding(const ComputationDataHandle& handle,
handle, 1491 const OpSharding& sharding) { 1492 tensorflow::mutex_lock lock(mutex_); 1493 1494 int64 handle_value = handle.handle(); 1495 if (session_computation_.requests().count(handle_value) == 0) { 1496 return InvalidArgument("Invalid handle in SetOpSharding (%lld)", 1497 handle_value); 1498 } 1499 *session_computation_.mutable_requests() 1500 ->at(handle_value) 1501 .mutable_request() 1502 ->mutable_sharding() = sharding; 1503 return Status::OK(); 1504 } 1505 1506 Status UserComputation::SetReturnValue(const ComputationDataHandle& handle) { 1507 tensorflow::mutex_lock lock(mutex_); 1508 1509 if (!(handle.handle() > 0 && handle.handle() < next_handle_value_)) { 1510 return InvalidArgument("Invalid handle in SetReturnValue"); 1511 } 1512 1513 handle_to_return_ = handle; 1514 1515 VLOG(1) << "SetReturnValue of computation \"" << name() << "\" fixed to " 1516 << GetVersionedHandleInternal(); 1517 1518 return Status::OK(); 1519 } 1520 1521 VersionedComputationHandle UserComputation::GetVersionedHandle() const { 1522 tensorflow::mutex_lock lock(mutex_); 1523 return GetVersionedHandleInternal(); 1524 } 1525 1526 VersionedComputationHandle UserComputation::GetVersionedHandleInternal() const { 1527 VersionedComputationHandle versioned_handle; 1528 versioned_handle.handle = session_computation_.computation_handle(); 1529 1530 if (handle_to_return_.handle() > 0) { 1531 // A specific handle has been requested for the result of the computation. 1532 versioned_handle.version = handle_to_return_.handle(); 1533 } else { 1534 // A version value is simply the most recently assigned 1535 // ComputationDataHandle value, ie the handle value of the root of the 1536 // computation. 1537 versioned_handle.version = next_handle_value_ - 1; 1538 } 1539 1540 return versioned_handle; 1541 } 1542 1543 VersionedComputationHandle UserComputation::GetVersionedHandleAtOperation( 1544 const ComputationDataHandle& operation) const { 1545 tensorflow::mutex_lock lock(mutex_); 1546 1547 // The version at which an operation was added is simply the handle value of 1548 // the ComputationDataHandle. 1549 VersionedComputationHandle versioned_handle; 1550 versioned_handle.handle = session_computation_.computation_handle(); 1551 versioned_handle.version = operation.handle(); 1552 return versioned_handle; 1553 } 1554 1555 VersionedComputationHandle::Version UserComputation::version() const { 1556 return GetVersionedHandle().version; 1557 } 1558 1559 namespace { 1560 1561 // Returns true if the operation type corresponding to the given opcase can be 1562 // the root of the computation. 1563 bool CanBeRoot(const OpRequest::OpCase& op_case) { 1564 switch (op_case) { 1565 case OpRequest::kTraceRequest: 1566 case OpRequest::kSendRequest: 1567 case OpRequest::kOutfeedRequest: 1568 return false; 1569 default: 1570 return true; 1571 } 1572 } 1573 1574 // Returns a pointer to the operation with the given data handle value in the 1575 // given SessionComputation. 1576 StatusOr<const OperationRequest*> LookUpRequest( 1577 int64 handle_value, const SessionComputation& session_computation) { 1578 if (session_computation.requests().count(handle_value) == 0) { 1579 return InvalidArgument("no ComputationDataHandle value %lld", handle_value); 1580 } 1581 return &session_computation.requests().at(handle_value); 1582 } 1583 1584 // Returns the OperationRequest corresponding to the root (result) of the 1585 // session computation. 
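// For example (illustrative): if the request at the given version is a
// SendRequest, it cannot be a root, so the walk below falls back to the
// next-lower handle value until it reaches a request whose op case satisfies
// CanBeRoot.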
1586 StatusOr<const OperationRequest*> GetRoot(
1587     VersionedComputationHandle::Version version,
1588     const SessionComputation& session_computation) {
1589   TF_RET_CHECK(version > 0);
1590   // Not all instructions can be roots. Walk backwards from the operation
1591   // indicated by this version until a valid root is found.
1592   const OperationRequest* root_request = nullptr;
1593   while (version > 0) {
1594     TF_ASSIGN_OR_RETURN(root_request,
1595                         LookUpRequest(version, session_computation));
1596     if (CanBeRoot(root_request->request().op_case())) {
1597       break;
1598     }
1599     version--;
1600   }
1601   if (version == 0) {
1602     return InternalError("Computation contains no root operation");
1603   }
1604   return root_request;
1605 }
1606
1607 }  // namespace
1608
1609 StatusOr<std::shared_ptr<const ProgramShape>>
1610 UserComputation::ComputeProgramShape(
1611     VersionedComputationHandle::Version version) const {
1612   tensorflow::mutex_lock lock(mutex_);
1613
1614   TF_RET_CHECK(version > 0 && version < next_handle_value_);
1615
1616   if (program_shape_ == nullptr || program_shape_version_ != version) {
1617     // ProgramShape has not been computed yet, or is for a different
1618     // version. Compute it now.
1619     TF_RETURN_IF_ERROR(CheckParametersAreContiguous(version));
1620
1621     auto program_shape = MakeUnique<ProgramShape>();
1622     for (int64 request_num = 1; request_num <= version; ++request_num) {
1623       const OperationRequest& request =
1624           session_computation_.requests().at(request_num);
1625       if (request.request().op_case() == OpRequest::kParameterRequest) {
1626         const ParameterRequest& parameter_request =
1627             request.request().parameter_request();
1628         int64 param_no = parameter_request.parameter();
1629         // Parameters may be out of order so expand ProgramShape parameters
1630         // until it is at least large enough to hold the current parameter
1631         // number.
1632         while (program_shape->parameters_size() <= param_no) {
1633           program_shape->add_parameters();
1634           program_shape->add_parameter_names();
1635         }
1636         *program_shape->mutable_parameters(param_no) = request.output_shape();
1637         *program_shape->mutable_parameter_names(param_no) =
1638             parameter_request.name();
1639       }
1640     }
1641
1642     // The root determines the output shape.
1643     TF_ASSIGN_OR_RETURN(const OperationRequest* root_request,
1644                         GetRoot(version, session_computation_));
1645     *program_shape->mutable_result() = root_request->output_shape();
1646     if (ShapeUtil::IsOpaque(program_shape->result())) {
1647       return Unimplemented("Computation results cannot be opaque");
1648     }
1649
1650     program_shape_ = std::move(program_shape);
1651     program_shape_version_ = version;
1652   }
1653
1654   return program_shape_;
1655 }
1656
1657 namespace {
1658
1659 // A visitor which checks whether an operation is purely functional, meaning
1660 // that it doesn't depend on any parameter with an index >= num_parameters.
1661 // The visitor walks the computation starting at a given operation; it sets
1662 // is_functional to false if, e.g., such a parameter or an RNG is encountered.
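// For example (illustrative): with num_parameters = 1, a computation rooted
// at Add(Parameter(0), Constant(42)) is functional, whereas one rooted at
// Add(Parameter(1), Constant(42)) or at Add(Rng(...), Constant(42)) is not.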
1663 void PureFunctionalVisitor(const SessionComputation& session_computation,
1664                            const ComputationDataHandle& handle,
1665                            int64 num_parameters, std::set<int64>* visited,
1666                            bool* is_functional) {
1667   if (visited->count(handle.handle()) != 0 || !*is_functional) {
1668     return;
1669   }
1670
1671   const OperationRequest& request =
1672       session_computation.requests().at(handle.handle());
1673   switch (request.request().op_case()) {
1674     case OpRequest::kRngRequest:
1675       *is_functional = false;
1676       break;
1677
1678     case OpRequest::kConstantRequest:
1679       break;
1680
1681     case OpRequest::kGetTupleElementRequest: {
1682       const GetTupleElementRequest& get_tuple_element_request =
1683           request.request().get_tuple_element_request();
1684       PureFunctionalVisitor(session_computation,
1685                             get_tuple_element_request.operand(), num_parameters,
1686                             visited, is_functional);
1687       break;
1688     }
1689
1690     case OpRequest::kSliceRequest: {
1691       const SliceRequest& slice_request = request.request().slice_request();
1692       PureFunctionalVisitor(session_computation, slice_request.operand(),
1693                             num_parameters, visited, is_functional);
1694       break;
1695     }
1696
1697     case OpRequest::kDynamicSliceRequest: {
1698       const DynamicSliceRequest& dynamic_slice_request =
1699           request.request().dynamic_slice_request();
1700       PureFunctionalVisitor(session_computation,
1701                             dynamic_slice_request.operand(), num_parameters,
1702                             visited, is_functional);
1703       PureFunctionalVisitor(session_computation,
1704                             dynamic_slice_request.start_indices(),
1705                             num_parameters, visited, is_functional);
1706       break;
1707     }
1708
1709     case OpRequest::kDynamicUpdateSliceRequest: {
1710       const DynamicUpdateSliceRequest& dynamic_update_slice_request =
1711           request.request().dynamic_update_slice_request();
1712       PureFunctionalVisitor(session_computation,
1713                             dynamic_update_slice_request.operand(),
1714                             num_parameters, visited, is_functional);
1715       PureFunctionalVisitor(session_computation,
1716                             dynamic_update_slice_request.update(),
1717                             num_parameters, visited, is_functional);
1718       PureFunctionalVisitor(session_computation,
1719                             dynamic_update_slice_request.start_indices(),
1720                             num_parameters, visited, is_functional);
1721       break;
1722     }
1723
1724     case OpRequest::kConcatenateRequest: {
1725       const ConcatenateRequest& concatenate_request =
1726           request.request().concatenate_request();
1727       for (const ComputationDataHandle& handle :
1728            concatenate_request.operands()) {
1729         PureFunctionalVisitor(session_computation, handle, num_parameters,
1730                               visited, is_functional);
1731       }
1732       break;
1733     }
1734
1735     case OpRequest::kConvolveRequest: {
1736       const ConvolveRequest& convolve_request =
1737           request.request().convolve_request();
1738       PureFunctionalVisitor(session_computation, convolve_request.lhs(),
1739                             num_parameters, visited, is_functional);
1740       PureFunctionalVisitor(session_computation, convolve_request.rhs(),
1741                             num_parameters, visited, is_functional);
1742       break;
1743     }
1744
1745     case OpRequest::kFftRequest: {
1746       const FftRequest& fft_request = request.request().fft_request();
1747       PureFunctionalVisitor(session_computation, fft_request.operand(),
1748                             num_parameters, visited, is_functional);
1749       break;
1750     }
1751
1752     case OpRequest::kCrossReplicaSumRequest: {
1753       // TODO(b/33009255): Implement constant folding for cross replica sum.
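      // A cross-replica sum depends on values computed by other replicas, so
      // it is conservatively treated as non-functional here.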
1754 *is_functional = false; 1755 break; 1756 } 1757 1758 case OpRequest::kInfeedRequest: { 1759 *is_functional = false; 1760 break; 1761 } 1762 1763 case OpRequest::kOutfeedRequest: { 1764 *is_functional = false; 1765 break; 1766 } 1767 1768 case OpRequest::kHostComputeRequest: { 1769 *is_functional = false; 1770 break; 1771 } 1772 1773 case OpRequest::kCallRequest: { 1774 const CallRequest& call_request = request.request().call_request(); 1775 for (const ComputationDataHandle& handle : call_request.operands()) { 1776 PureFunctionalVisitor(session_computation, handle, num_parameters, 1777 visited, is_functional); 1778 } 1779 // TODO(b/32495713): We aren't checking the to_apply computation itself, 1780 // so we conservatively say that computations containing the Call op 1781 // cannot be constant. We cannot set is_functional=false in other similar 1782 // cases since we're already relying on IsConstant to return true. 1783 *is_functional = false; 1784 break; 1785 } 1786 1787 case OpRequest::kCustomCallRequest: { 1788 *is_functional = false; 1789 break; 1790 } 1791 1792 case OpRequest::kDotRequest: { 1793 const DotRequest& dot_request = request.request().dot_request(); 1794 PureFunctionalVisitor(session_computation, dot_request.lhs(), 1795 num_parameters, visited, is_functional); 1796 PureFunctionalVisitor(session_computation, dot_request.rhs(), 1797 num_parameters, visited, is_functional); 1798 break; 1799 } 1800 1801 case OpRequest::kSendRequest: { 1802 *is_functional = false; 1803 break; 1804 } 1805 1806 case OpRequest::kRecvRequest: { 1807 *is_functional = false; 1808 break; 1809 } 1810 1811 case OpRequest::kMapRequest: { 1812 const MapRequest& map_request = request.request().map_request(); 1813 for (const ComputationDataHandle& handle : map_request.operands()) { 1814 PureFunctionalVisitor(session_computation, handle, num_parameters, 1815 visited, is_functional); 1816 } 1817 // TODO(b/32495713): We aren't checking the to_apply computation itself. 1818 break; 1819 } 1820 1821 case OpRequest::kReduceRequest: { 1822 const ReduceRequest& reduce_request = request.request().reduce_request(); 1823 PureFunctionalVisitor(session_computation, reduce_request.operand(), 1824 num_parameters, visited, is_functional); 1825 PureFunctionalVisitor(session_computation, reduce_request.init_value(), 1826 num_parameters, visited, is_functional); 1827 // TODO(b/32495713): We aren't checking the to_apply computation itself. 1828 break; 1829 } 1830 1831 case OpRequest::kReduceWindowRequest: { 1832 const ReduceWindowRequest& reduce_window_request = 1833 request.request().reduce_window_request(); 1834 PureFunctionalVisitor(session_computation, 1835 reduce_window_request.operand(), num_parameters, 1836 visited, is_functional); 1837 PureFunctionalVisitor(session_computation, 1838 reduce_window_request.init_value(), num_parameters, 1839 visited, is_functional); 1840 // TODO(b/32495713): We aren't checking the to_apply computation itself. 
1841 break; 1842 } 1843 1844 case OpRequest::kSelectAndScatterRequest: { 1845 const SelectAndScatterRequest& select_and_scatter_request = 1846 request.request().select_and_scatter_request(); 1847 PureFunctionalVisitor(session_computation, 1848 select_and_scatter_request.operand(), 1849 num_parameters, visited, is_functional); 1850 PureFunctionalVisitor(session_computation, 1851 select_and_scatter_request.source(), num_parameters, 1852 visited, is_functional); 1853 PureFunctionalVisitor(session_computation, 1854 select_and_scatter_request.init_value(), 1855 num_parameters, visited, is_functional); 1856 // TODO(b/32495713): We aren't checking the select and scatter 1857 // computations themselves. 1858 break; 1859 } 1860 1861 case OpRequest::kBroadcastRequest: { 1862 const BroadcastRequest& broadcast_request = 1863 request.request().broadcast_request(); 1864 PureFunctionalVisitor(session_computation, broadcast_request.operand(), 1865 num_parameters, visited, is_functional); 1866 break; 1867 } 1868 1869 case OpRequest::kReshapeRequest: { 1870 const ReshapeRequest& reshape_request = 1871 request.request().reshape_request(); 1872 PureFunctionalVisitor(session_computation, reshape_request.operand(), 1873 num_parameters, visited, is_functional); 1874 break; 1875 } 1876 1877 case OpRequest::kReverseRequest: { 1878 const ReverseRequest& reverse_request = 1879 request.request().reverse_request(); 1880 PureFunctionalVisitor(session_computation, reverse_request.operand(), 1881 num_parameters, visited, is_functional); 1882 break; 1883 } 1884 1885 case OpRequest::kPadRequest: { 1886 const PadRequest& pad_request = request.request().pad_request(); 1887 PureFunctionalVisitor(session_computation, pad_request.operand(), 1888 num_parameters, visited, is_functional); 1889 PureFunctionalVisitor(session_computation, pad_request.padding_value(), 1890 num_parameters, visited, is_functional); 1891 break; 1892 } 1893 1894 case OpRequest::kParameterRequest: { 1895 const ParameterRequest& parameter_request = 1896 request.request().parameter_request(); 1897 if (parameter_request.parameter() >= num_parameters) { 1898 *is_functional = false; 1899 } 1900 break; 1901 } 1902 1903 case OpRequest::kConvertRequest: { 1904 const ConvertRequest& convert_request = 1905 request.request().convert_request(); 1906 PureFunctionalVisitor(session_computation, convert_request.operand(), 1907 num_parameters, visited, is_functional); 1908 break; 1909 } 1910 1911 case OpRequest::kBitcastConvertRequest: { 1912 const ConvertRequest& convert_request = 1913 request.request().bitcast_convert_request(); 1914 PureFunctionalVisitor(session_computation, convert_request.operand(), 1915 num_parameters, visited, is_functional); 1916 break; 1917 } 1918 1919 case OpRequest::kWhileRequest: { 1920 const WhileRequest& while_request = request.request().while_request(); 1921 PureFunctionalVisitor(session_computation, while_request.init(), 1922 num_parameters, visited, is_functional); 1923 // TODO(b/32495713): We aren't checking the condition and body 1924 // computations themselves. 
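      // Unlike Map or Reduce above, While is conservatively marked
      // non-functional because the unvisited condition and body could contain
      // non-functional operations.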
1925 *is_functional = false; 1926 break; 1927 } 1928 1929 case OpRequest::kConditionalRequest: { 1930 const ConditionalRequest& conditional_request = 1931 request.request().conditional_request(); 1932 PureFunctionalVisitor(session_computation, 1933 conditional_request.predicate(), num_parameters, 1934 visited, is_functional); 1935 PureFunctionalVisitor(session_computation, 1936 conditional_request.true_operand(), num_parameters, 1937 visited, is_functional); 1938 PureFunctionalVisitor(session_computation, 1939 conditional_request.false_operand(), num_parameters, 1940 visited, is_functional); 1941 // TODO(b/32495713): We aren't checking the true and false computations 1942 // themselves. 1943 break; 1944 } 1945 1946 case OpRequest::kTernaryOpRequest: { 1947 const TernaryOpRequest& ternary_op_request = 1948 request.request().ternary_op_request(); 1949 PureFunctionalVisitor(session_computation, ternary_op_request.lhs(), 1950 num_parameters, visited, is_functional); 1951 PureFunctionalVisitor(session_computation, ternary_op_request.rhs(), 1952 num_parameters, visited, is_functional); 1953 PureFunctionalVisitor(session_computation, ternary_op_request.ehs(), 1954 num_parameters, visited, is_functional); 1955 break; 1956 } 1957 1958 case OpRequest::kTransposeRequest: { 1959 const TransposeRequest& transpose_request = 1960 request.request().transpose_request(); 1961 PureFunctionalVisitor(session_computation, transpose_request.operand(), 1962 num_parameters, visited, is_functional); 1963 break; 1964 } 1965 1966 case OpRequest::kVariadicOpRequest: { 1967 const VariadicOpRequest& variadic_op_request = 1968 request.request().variadic_op_request(); 1969 for (const ComputationDataHandle& handle : 1970 variadic_op_request.operands()) { 1971 PureFunctionalVisitor(session_computation, handle, num_parameters, 1972 visited, is_functional); 1973 } 1974 break; 1975 } 1976 1977 case OpRequest::kUnaryOpRequest: { 1978 const UnaryOpRequest& unary_op_request = 1979 request.request().unary_op_request(); 1980 PureFunctionalVisitor(session_computation, unary_op_request.operand(), 1981 num_parameters, visited, is_functional); 1982 break; 1983 } 1984 1985 case OpRequest::kBatchNormTrainingRequest: { 1986 const BatchNormTrainingRequest& batch_norm_training_request = 1987 request.request().batch_norm_training_request(); 1988 PureFunctionalVisitor(session_computation, 1989 batch_norm_training_request.operand(), 1990 num_parameters, visited, is_functional); 1991 PureFunctionalVisitor(session_computation, 1992 batch_norm_training_request.scale(), num_parameters, 1993 visited, is_functional); 1994 PureFunctionalVisitor(session_computation, 1995 batch_norm_training_request.offset(), 1996 num_parameters, visited, is_functional); 1997 break; 1998 } 1999 2000 case OpRequest::kBatchNormInferenceRequest: { 2001 const BatchNormInferenceRequest& batch_norm_inference_request = 2002 request.request().batch_norm_inference_request(); 2003 PureFunctionalVisitor(session_computation, 2004 batch_norm_inference_request.operand(), 2005 num_parameters, visited, is_functional); 2006 PureFunctionalVisitor(session_computation, 2007 batch_norm_inference_request.scale(), 2008 num_parameters, visited, is_functional); 2009 PureFunctionalVisitor(session_computation, 2010 batch_norm_inference_request.offset(), 2011 num_parameters, visited, is_functional); 2012 PureFunctionalVisitor(session_computation, 2013 batch_norm_inference_request.mean(), num_parameters, 2014 visited, is_functional); 2015 PureFunctionalVisitor(session_computation, 2016 
batch_norm_inference_request.variance(), 2017 num_parameters, visited, is_functional); 2018 break; 2019 } 2020 2021 case OpRequest::kBatchNormGradRequest: { 2022 const BatchNormGradRequest& batch_norm_grad_request = 2023 request.request().batch_norm_grad_request(); 2024 PureFunctionalVisitor(session_computation, 2025 batch_norm_grad_request.operand(), num_parameters, 2026 visited, is_functional); 2027 PureFunctionalVisitor(session_computation, 2028 batch_norm_grad_request.scale(), num_parameters, 2029 visited, is_functional); 2030 PureFunctionalVisitor(session_computation, batch_norm_grad_request.mean(), 2031 num_parameters, visited, is_functional); 2032 PureFunctionalVisitor(session_computation, 2033 batch_norm_grad_request.variance(), num_parameters, 2034 visited, is_functional); 2035 PureFunctionalVisitor(session_computation, 2036 batch_norm_grad_request.grad_output(), 2037 num_parameters, visited, is_functional); 2038 break; 2039 } 2040 2041 case OpRequest::kBinaryOpRequest: { 2042 const BinaryOpRequest& binary_op_request = 2043 request.request().binary_op_request(); 2044 PureFunctionalVisitor(session_computation, binary_op_request.lhs(), 2045 num_parameters, visited, is_functional); 2046 PureFunctionalVisitor(session_computation, binary_op_request.rhs(), 2047 num_parameters, visited, is_functional); 2048 break; 2049 } 2050 2051 case OpRequest::kGatherRequest: { 2052 PureFunctionalVisitor(session_computation, 2053 request.request().gather_request().input(), 2054 num_parameters, visited, is_functional); 2055 PureFunctionalVisitor(session_computation, 2056 request.request().gather_request().gather_indices(), 2057 num_parameters, visited, is_functional); 2058 break; 2059 } 2060 2061 case OpRequest::OP_NOT_SET: 2062 LOG(FATAL) << "OperationRequest doesn't contain a request"; 2063 2064 default: 2065 LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); 2066 } 2067 if (!*is_functional) { 2068 VLOG(1) << "Non-functional: " << request.request().DebugString(); 2069 } 2070 visited->insert(handle.handle()); 2071 } 2072 2073 } // namespace 2074 2075 StatusOr<bool> UserComputation::IsConstant(const ComputationDataHandle& handle, 2076 int64 num_parameters) { 2077 tensorflow::mutex_lock lock(mutex_); 2078 2079 // Verify that the handle is valid. 
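  // (An unknown handle yields the InvalidArgument status from LookUpRequest
  // below rather than an is-constant answer.)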
2080 auto operation_status = LookUpRequest(handle); 2081 if (!operation_status.ok()) { 2082 return operation_status.status(); 2083 } 2084 2085 bool is_constant = true; 2086 std::set<int64> visited; 2087 PureFunctionalVisitor(session_computation_, handle, num_parameters, &visited, 2088 &is_constant); 2089 2090 return is_constant; 2091 } 2092 2093 std::vector<VersionedComputationHandle> 2094 UserComputation::GetEmbeddedComputations( 2095 VersionedComputationHandle::Version version) const { 2096 tensorflow::mutex_lock lock(mutex_); 2097 2098 VLOG(1) 2099 << "GetEmbeddedComputations(" << name() << " " 2100 << VersionedComputationHandle{session_computation_.computation_handle(), 2101 version} 2102 << ")"; 2103 XLA_VLOG_LINES(3, session_computation_.DebugString()); 2104 2105 std::vector<VersionedComputationHandle> computations; 2106 std::vector<int64> sorted_handles; 2107 for (const auto& handle_request : session_computation_.requests()) { 2108 sorted_handles.push_back(handle_request.first); 2109 } 2110 std::sort(sorted_handles.begin(), sorted_handles.end()); 2111 for (int64 handle : sorted_handles) { 2112 const auto& handle_request = session_computation_.requests().find(handle); 2113 CHECK(handle_request != session_computation_.requests().end()); 2114 int64 handle_value = handle_request->first; 2115 if (handle_value <= version) { 2116 const OperationRequest& request = handle_request->second; 2117 switch (request.request().op_case()) { 2118 case OpRequest::kCallRequest: { 2119 CHECK_EQ(1, request.embedded_computation_versions_size()); 2120 const CallRequest& call_request = request.request().call_request(); 2121 const VersionedComputationHandle versioned_handle = { 2122 call_request.to_apply(), 2123 request.embedded_computation_versions(0)}; 2124 computations.push_back(versioned_handle); 2125 break; 2126 } 2127 2128 case OpRequest::kMapRequest: { 2129 CHECK_EQ(1, request.embedded_computation_versions_size()); 2130 const MapRequest& map_request = request.request().map_request(); 2131 const VersionedComputationHandle versioned_handle = { 2132 map_request.to_apply(), request.embedded_computation_versions(0)}; 2133 computations.push_back(versioned_handle); 2134 break; 2135 } 2136 2137 case OpRequest::kReduceRequest: { 2138 CHECK_EQ(1, request.embedded_computation_versions_size()); 2139 const ReduceRequest& reduce_request = 2140 request.request().reduce_request(); 2141 const VersionedComputationHandle versioned_handle = { 2142 reduce_request.to_apply(), 2143 request.embedded_computation_versions(0)}; 2144 computations.push_back(versioned_handle); 2145 break; 2146 } 2147 2148 case OpRequest::kReduceWindowRequest: { 2149 CHECK_EQ(1, request.embedded_computation_versions_size()); 2150 const ReduceWindowRequest& reduce_window_request = 2151 request.request().reduce_window_request(); 2152 const VersionedComputationHandle versioned_handle = { 2153 reduce_window_request.to_apply(), 2154 request.embedded_computation_versions(0)}; 2155 computations.push_back(versioned_handle); 2156 break; 2157 } 2158 2159 case OpRequest::kSelectAndScatterRequest: { 2160 CHECK_EQ(2, request.embedded_computation_versions_size()); 2161 const SelectAndScatterRequest& select_and_scatter_request = 2162 request.request().select_and_scatter_request(); 2163 const VersionedComputationHandle select_versioned_handle = { 2164 select_and_scatter_request.select(), 2165 request.embedded_computation_versions(0)}; 2166 computations.push_back(select_versioned_handle); 2167 const VersionedComputationHandle scatter_versioned_handle = { 2168 
select_and_scatter_request.scatter(), 2169 request.embedded_computation_versions(1)}; 2170 computations.push_back(scatter_versioned_handle); 2171 break; 2172 } 2173 2174 case OpRequest::kWhileRequest: { 2175 CHECK_EQ(2, request.embedded_computation_versions_size()); 2176 const WhileRequest& while_request = request.request().while_request(); 2177 const VersionedComputationHandle condition_versioned_handle = { 2178 while_request.condition(), 2179 request.embedded_computation_versions(0)}; 2180 computations.push_back(condition_versioned_handle); 2181 const VersionedComputationHandle body_versioned_handle = { 2182 while_request.body(), request.embedded_computation_versions(1)}; 2183 computations.push_back(body_versioned_handle); 2184 break; 2185 } 2186 2187 case OpRequest::kConditionalRequest: { 2188 CHECK_EQ(2, request.embedded_computation_versions_size()); 2189 const ConditionalRequest& conditional_request = 2190 request.request().conditional_request(); 2191 const VersionedComputationHandle true_computation_versioned_handle = { 2192 conditional_request.true_computation(), 2193 request.embedded_computation_versions(0)}; 2194 computations.push_back(true_computation_versioned_handle); 2195 const VersionedComputationHandle false_computation_versioned_handle = 2196 {conditional_request.false_computation(), 2197 request.embedded_computation_versions(1)}; 2198 computations.push_back(false_computation_versioned_handle); 2199 break; 2200 } 2201 2202 default: 2203 // No embedded computation. 2204 break; 2205 } 2206 } 2207 } 2208 VLOG(2) << "Embedded computations: " 2209 << tensorflow::str_util::Join( 2210 computations, ", ", 2211 [](string* out, const VersionedComputationHandle& h) { 2212 out->append(h.ToString()); 2213 }); 2214 return computations; 2215 } 2216 2217 StatusOr<const OperationRequest*> 2218 UserComputation::LookUpRequestForErrorReporting( 2219 const ComputationDataHandle& handle) const { 2220 tensorflow::mutex_lock lock(mutex_); 2221 return LookUpRequest(handle); 2222 } 2223 2224 tensorflow::gtl::optional<const OpMetadata*> UserComputation::ParameterMetadata( 2225 int parameter_number) const { 2226 tensorflow::mutex_lock lock(mutex_); 2227 auto it = parameters_.find(parameter_number); 2228 if (it == parameters_.end()) { 2229 return tensorflow::gtl::nullopt; 2230 } 2231 OperationRequest* op = it->second; 2232 return &op->request().metadata(); 2233 } 2234 2235 Status UserComputation::RemapEmbeddedComputations( 2236 const std::map<int64, ComputationHandle>& old_to_new) { 2237 auto update = [&old_to_new](ComputationHandle* to_update) -> Status { 2238 int64 old = to_update->handle(); 2239 auto it = old_to_new.find(old); 2240 if (it == old_to_new.end()) { 2241 string mapping = tensorflow::str_util::Join( 2242 old_to_new, ", ", 2243 [](string* out, std::pair<int64, ComputationHandle> element) { 2244 tensorflow::strings::Appendf(out, "%lld:%lld", element.first, 2245 element.second.handle()); 2246 }); 2247 return NotFound( 2248 "could not find referenced (old) computation handle in mapping: " 2249 "%lld; mapping: {%s}", 2250 old, mapping.c_str()); 2251 } 2252 VLOG(2) << "remapping " << old << " to " << it->second.handle(); 2253 *to_update = it->second; 2254 return Status::OK(); 2255 }; 2256 TF_RETURN_IF_ERROR(update(session_computation_.mutable_computation_handle())); 2257 for (auto& handle_request : *session_computation_.mutable_requests()) { 2258 OperationRequest& request = handle_request.second; 2259 switch (request.request().op_case()) { 2260 case OpRequest::kCallRequest: { 2261 
TF_RET_CHECK(1 == request.embedded_computation_versions_size()); 2262 CallRequest* call_request = 2263 request.mutable_request()->mutable_call_request(); 2264 TF_RETURN_IF_ERROR(update(call_request->mutable_to_apply())); 2265 break; 2266 } 2267 case OpRequest::kMapRequest: { 2268 TF_RET_CHECK(1 == request.embedded_computation_versions_size()); 2269 MapRequest* map_request = 2270 request.mutable_request()->mutable_map_request(); 2271 TF_RETURN_IF_ERROR(update(map_request->mutable_to_apply())); 2272 break; 2273 } 2274 case OpRequest::kReduceRequest: { 2275 TF_RET_CHECK(1 == request.embedded_computation_versions_size()); 2276 ReduceRequest* reduce_request = 2277 request.mutable_request()->mutable_reduce_request(); 2278 TF_RETURN_IF_ERROR(update(reduce_request->mutable_to_apply())); 2279 break; 2280 } 2281 case OpRequest::kReduceWindowRequest: { 2282 TF_RET_CHECK(1 == request.embedded_computation_versions_size()); 2283 ReduceWindowRequest* reduce_window_request = 2284 request.mutable_request()->mutable_reduce_window_request(); 2285 TF_RETURN_IF_ERROR(update(reduce_window_request->mutable_to_apply())); 2286 break; 2287 } 2288 case OpRequest::kSelectAndScatterRequest: { 2289 TF_RET_CHECK(2 == request.embedded_computation_versions_size()); 2290 SelectAndScatterRequest* select_and_scatter_request = 2291 request.mutable_request()->mutable_select_and_scatter_request(); 2292 TF_RETURN_IF_ERROR( 2293 update(select_and_scatter_request->mutable_select())); 2294 TF_RETURN_IF_ERROR( 2295 update(select_and_scatter_request->mutable_scatter())); 2296 break; 2297 } 2298 case OpRequest::kWhileRequest: { 2299 TF_RET_CHECK(2 == request.embedded_computation_versions_size()); 2300 WhileRequest* while_request = 2301 request.mutable_request()->mutable_while_request(); 2302 TF_RETURN_IF_ERROR(update(while_request->mutable_condition())); 2303 TF_RETURN_IF_ERROR(update(while_request->mutable_body())); 2304 break; 2305 } 2306 case OpRequest::kConditionalRequest: { 2307 TF_RET_CHECK(2 == request.embedded_computation_versions_size()); 2308 ConditionalRequest* conditional_request = 2309 request.mutable_request()->mutable_conditional_request(); 2310 TF_RETURN_IF_ERROR( 2311 update(conditional_request->mutable_true_computation())); 2312 TF_RETURN_IF_ERROR( 2313 update(conditional_request->mutable_false_computation())); 2314 break; 2315 } 2316 default: 2317 // No embedded computation. 2318 TF_RET_CHECK(0 == request.embedded_computation_versions_size()); 2319 break; 2320 } 2321 } 2322 return Status::OK(); 2323 } 2324 2325 SessionComputation UserComputation::CloneSessionComputation( 2326 VersionedComputationHandle::Version version) const { 2327 tensorflow::mutex_lock lock(mutex_); 2328 SessionComputation result = session_computation_; 2329 // Erase all the requests that exceed the version specified. 2330 // There's no lower_bound method on tensorflow::protobuf::Map so we iterate 2331 // all the elements. 
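  // Note that erase() on the protobuf Map returns an iterator to the element
  // following the erased one, which is what lets the loop below erase entries
  // while iterating without invalidating its position.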
2332 auto it = result.mutable_requests()->begin(); 2333 while (it != result.mutable_requests()->end()) { 2334 if (it->first > version) { 2335 it = result.mutable_requests()->erase(it); 2336 } else { 2337 ++it; 2338 } 2339 } 2340 return result; 2341 } 2342 2343 StatusOr<const OperationRequest*> UserComputation::LookUpRequest( 2344 const ComputationDataHandle& handle) const { 2345 int64 handle_value = handle.handle(); 2346 if (session_computation_.requests().count(handle_value) == 0) { 2347 return InvalidArgument("no ComputationDataHandle value %lld", handle_value); 2348 } 2349 return &session_computation_.requests().at(handle_value); 2350 } 2351 2352 Status UserComputation::CheckParametersAreContiguous( 2353 VersionedComputationHandle::Version version) const { 2354 TF_RET_CHECK(version > 0 && version < next_handle_value_); 2355 2356 // Determine number of parameter inputs at the given version. 2357 std::map<int64, const ParameterRequest*> parameter_requests; 2358 for (int64 request_num = 1; request_num <= version; ++request_num) { 2359 const OperationRequest& request = 2360 session_computation_.requests().at(request_num); 2361 2362 if (request.request().op_case() == OpRequest::kParameterRequest) { 2363 const ParameterRequest& parameter_request = 2364 request.request().parameter_request(); 2365 // Duplicate parameters should be checked when parameter requests are 2366 // added. 2367 TF_RET_CHECK(0 == 2368 parameter_requests.count(parameter_request.parameter())); 2369 parameter_requests[parameter_request.parameter()] = ¶meter_request; 2370 } 2371 } 2372 2373 for (int64 i = 0; i < parameter_requests.size(); ++i) { 2374 auto it = parameter_requests.find(i); 2375 if (it == parameter_requests.end()) { 2376 return FailedPrecondition( 2377 "computation %s does not have all its parameters populated " 2378 "sequentially, missing parameter %lld", 2379 name_.c_str(), i); 2380 } 2381 } 2382 2383 return Status::OK(); 2384 } 2385 2386 namespace { 2387 2388 // Helper class which builds an HLO computation from a SessionComputation. To 2389 // construct the HLO computation, the SessionComputation graph is walked in 2390 // DFS order lowering each OperationRequest to an HLO instruction. 2391 class ComputationLowerer { 2392 public: 2393 static StatusOr<std::unique_ptr<HloComputation>> Lower( 2394 const string& computation_name, 2395 const SessionComputation& session_computation, 2396 VersionedComputationHandle::Version version, 2397 UserComputation::HloComputationResolver hlo_resolver, 2398 const DebugOptions& debug_options, 2399 bool include_unreachable_instructions) { 2400 ComputationLowerer lowerer(computation_name, session_computation, version, 2401 std::move(hlo_resolver), debug_options, 2402 include_unreachable_instructions); 2403 return lowerer.Lower(); 2404 } 2405 2406 private: 2407 ComputationLowerer(const string& computation_name, 2408 const SessionComputation& session_computation, 2409 VersionedComputationHandle::Version version, 2410 UserComputation::HloComputationResolver hlo_resolver, 2411 const DebugOptions& debug_options, 2412 bool include_unreachable_instructions) 2413 : hlo_builder_(computation_name), 2414 session_computation_(session_computation), 2415 version_(version), 2416 hlo_resolver_(std::move(hlo_resolver)), 2417 debug_options_(debug_options), 2418 include_unreachable_instructions_(include_unreachable_instructions) {} 2419 2420 // Build an HLO computation from the SessionComputation at the given 2421 // version. 
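  // Lowering happens in two passes (see the definition below): a postorder
  // traversal from the root request that emits one HloInstruction per visited
  // OperationRequest, then, if include_unreachable_instructions_ is set, the
  // same traversal rooted at every remaining request so that unreachable
  // operations are lowered as well.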
2422 StatusOr<std::unique_ptr<HloComputation>> Lower(); 2423 2424 private: 2425 // Traverses the computation 'root' using a DFS, calling 'visit' in postorder. 2426 void TraversePostorder( 2427 const ComputationDataHandle& root, 2428 std::unordered_map<int64, HloInstruction*>* visited, 2429 const std::function<void(const ComputationDataHandle&)>& visit); 2430 2431 // DFS visitor of the UserComputation operations which lowers the operations 2432 // to HLO instructions. 2433 void Visit(const ComputationDataHandle& handle, 2434 std::unordered_map<int64, HloInstruction*>* instructions); 2435 2436 // Resolves a ComputationHandle and Version to a previously lowered 2437 // HloComputation using the hlo_resolver_ function. 2438 HloComputation* ResolveComputation( 2439 const ComputationHandle& handle, 2440 VersionedComputationHandle::Version version); 2441 2442 // This function takes an input value which is being implicitly broadcast into 2443 // an output shape and figures out the right kBroadcast instruction(s) 2444 // necessary to replicate the implicit broadcast semantics explicitly. 2445 HloInstruction* ImplicitBroadcastToExplicitBroadcast( 2446 HloInstruction* operand, const Shape& output_shape); 2447 2448 HloComputation::Builder hlo_builder_; 2449 const SessionComputation& session_computation_; 2450 const VersionedComputationHandle::Version version_; 2451 const UserComputation::HloComputationResolver hlo_resolver_; 2452 const DebugOptions& debug_options_; 2453 const bool include_unreachable_instructions_; 2454 }; 2455 2456 // Calls 'apply' on each operand of 'request'. 2457 static void ForEachOperand( 2458 const OperationRequest& request, 2459 const std::function<void(const ComputationDataHandle& param)>& apply) { 2460 switch (request.request().op_case()) { 2461 case OpRequest::kRngRequest: { 2462 const RngRequest& rng_request = request.request().rng_request(); 2463 for (const ComputationDataHandle& param : rng_request.parameter()) { 2464 apply(param); 2465 } 2466 break; 2467 } 2468 2469 case OpRequest::kConstantRequest: 2470 break; 2471 case OpRequest::kGetTupleElementRequest: { 2472 const GetTupleElementRequest& get_tuple_element_request = 2473 request.request().get_tuple_element_request(); 2474 apply(get_tuple_element_request.operand()); 2475 break; 2476 } 2477 2478 case OpRequest::kSliceRequest: { 2479 const SliceRequest& slice_request = request.request().slice_request(); 2480 apply(slice_request.operand()); 2481 break; 2482 } 2483 2484 case OpRequest::kDynamicSliceRequest: { 2485 const DynamicSliceRequest& dynamic_slice_request = 2486 request.request().dynamic_slice_request(); 2487 apply(dynamic_slice_request.operand()); 2488 apply(dynamic_slice_request.start_indices()); 2489 break; 2490 } 2491 2492 case OpRequest::kDynamicUpdateSliceRequest: { 2493 const DynamicUpdateSliceRequest& dynamic_update_slice_request = 2494 request.request().dynamic_update_slice_request(); 2495 apply(dynamic_update_slice_request.operand()); 2496 apply(dynamic_update_slice_request.update()); 2497 apply(dynamic_update_slice_request.start_indices()); 2498 break; 2499 } 2500 2501 case OpRequest::kConcatenateRequest: { 2502 const ConcatenateRequest& concatenate_request = 2503 request.request().concatenate_request(); 2504 for (const ComputationDataHandle& handle : 2505 concatenate_request.operands()) { 2506 apply(handle); 2507 } 2508 break; 2509 } 2510 2511 case OpRequest::kConvolveRequest: { 2512 const ConvolveRequest& convolve_request = 2513 request.request().convolve_request(); 2514 apply(convolve_request.lhs()); 
2515 apply(convolve_request.rhs()); 2516 break; 2517 } 2518 2519 case OpRequest::kFftRequest: { 2520 const FftRequest& fft_request = request.request().fft_request(); 2521 apply(fft_request.operand()); 2522 break; 2523 } 2524 2525 case OpRequest::kBatchNormTrainingRequest: { 2526 const BatchNormTrainingRequest& batch_norm_training_request = 2527 request.request().batch_norm_training_request(); 2528 2529 apply(batch_norm_training_request.operand()); 2530 apply(batch_norm_training_request.scale()); 2531 apply(batch_norm_training_request.offset()); 2532 break; 2533 } 2534 2535 case OpRequest::kBatchNormInferenceRequest: { 2536 const BatchNormInferenceRequest& batch_norm_inference_request = 2537 request.request().batch_norm_inference_request(); 2538 2539 apply(batch_norm_inference_request.operand()); 2540 apply(batch_norm_inference_request.scale()); 2541 apply(batch_norm_inference_request.offset()); 2542 apply(batch_norm_inference_request.mean()); 2543 apply(batch_norm_inference_request.variance()); 2544 break; 2545 } 2546 2547 case OpRequest::kBatchNormGradRequest: { 2548 const BatchNormGradRequest& batch_norm_grad_request = 2549 request.request().batch_norm_grad_request(); 2550 2551 apply(batch_norm_grad_request.operand()); 2552 apply(batch_norm_grad_request.scale()); 2553 apply(batch_norm_grad_request.mean()); 2554 apply(batch_norm_grad_request.variance()); 2555 apply(batch_norm_grad_request.grad_output()); 2556 break; 2557 } 2558 2559 case OpRequest::kCrossReplicaSumRequest: { 2560 const CrossReplicaSumRequest& cross_replica_sum_request = 2561 request.request().cross_replica_sum_request(); 2562 apply(cross_replica_sum_request.operand()); 2563 break; 2564 } 2565 2566 case OpRequest::kInfeedRequest: 2567 break; 2568 2569 case OpRequest::kOutfeedRequest: { 2570 const OutfeedRequest& outfeed_request = 2571 request.request().outfeed_request(); 2572 apply(outfeed_request.operand()); 2573 break; 2574 } 2575 2576 case OpRequest::kMapRequest: { 2577 const MapRequest& map_request = request.request().map_request(); 2578 for (const ComputationDataHandle& handle : map_request.operands()) { 2579 apply(handle); 2580 } 2581 break; 2582 } 2583 2584 case OpRequest::kReduceRequest: { 2585 const ReduceRequest& reduce_request = request.request().reduce_request(); 2586 apply(reduce_request.operand()); 2587 apply(reduce_request.init_value()); 2588 break; 2589 } 2590 2591 case OpRequest::kReduceWindowRequest: { 2592 const ReduceWindowRequest& reduce_window_request = 2593 request.request().reduce_window_request(); 2594 apply(reduce_window_request.operand()); 2595 apply(reduce_window_request.init_value()); 2596 break; 2597 } 2598 2599 case OpRequest::kSelectAndScatterRequest: { 2600 const SelectAndScatterRequest& select_and_scatter_request = 2601 request.request().select_and_scatter_request(); 2602 apply(select_and_scatter_request.operand()); 2603 apply(select_and_scatter_request.source()); 2604 apply(select_and_scatter_request.init_value()); 2605 2606 break; 2607 } 2608 2609 case OpRequest::kBroadcastRequest: { 2610 const BroadcastRequest& broadcast_request = 2611 request.request().broadcast_request(); 2612 apply(broadcast_request.operand()); 2613 break; 2614 } 2615 2616 case OpRequest::kReshapeRequest: { 2617 const ReshapeRequest& reshape_request = 2618 request.request().reshape_request(); 2619 apply(reshape_request.operand()); 2620 break; 2621 } 2622 2623 case OpRequest::kTransposeRequest: { 2624 const TransposeRequest& transpose_request = 2625 request.request().transpose_request(); 2626 
apply(transpose_request.operand()); 2627 break; 2628 } 2629 2630 case OpRequest::kReverseRequest: { 2631 const ReverseRequest& reverse_request = 2632 request.request().reverse_request(); 2633 apply(reverse_request.operand()); 2634 break; 2635 } 2636 2637 case OpRequest::kPadRequest: { 2638 const PadRequest& pad_request = request.request().pad_request(); 2639 apply(pad_request.operand()); 2640 apply(pad_request.padding_value()); 2641 break; 2642 } 2643 2644 case OpRequest::kRecvRequest: 2645 case OpRequest::kParameterRequest: 2646 break; 2647 2648 case OpRequest::kConvertRequest: { 2649 const ConvertRequest& convert_request = 2650 request.request().convert_request(); 2651 apply(convert_request.operand()); 2652 break; 2653 } 2654 2655 case OpRequest::kBitcastConvertRequest: { 2656 const ConvertRequest& convert_request = 2657 request.request().bitcast_convert_request(); 2658 apply(convert_request.operand()); 2659 break; 2660 } 2661 2662 case OpRequest::kWhileRequest: { 2663 const WhileRequest& while_request = request.request().while_request(); 2664 apply(while_request.init()); 2665 break; 2666 } 2667 2668 case OpRequest::kConditionalRequest: { 2669 const ConditionalRequest& conditional_request = 2670 request.request().conditional_request(); 2671 apply(conditional_request.predicate()); 2672 apply(conditional_request.true_operand()); 2673 apply(conditional_request.false_operand()); 2674 break; 2675 } 2676 2677 case OpRequest::kTernaryOpRequest: { 2678 const TernaryOpRequest& ternary_op_request = 2679 request.request().ternary_op_request(); 2680 apply(ternary_op_request.lhs()); 2681 apply(ternary_op_request.rhs()); 2682 apply(ternary_op_request.ehs()); 2683 break; 2684 } 2685 2686 case OpRequest::kVariadicOpRequest: { 2687 const VariadicOpRequest& variadic_op_request = 2688 request.request().variadic_op_request(); 2689 for (const ComputationDataHandle& handle : 2690 variadic_op_request.operands()) { 2691 apply(handle); 2692 } 2693 break; 2694 } 2695 2696 case OpRequest::kCallRequest: { 2697 const CallRequest& call_request = request.request().call_request(); 2698 for (const ComputationDataHandle& handle : call_request.operands()) { 2699 apply(handle); 2700 } 2701 break; 2702 } 2703 2704 case OpRequest::kCustomCallRequest: { 2705 const CustomCallRequest& cc_request = 2706 request.request().custom_call_request(); 2707 for (const ComputationDataHandle& operand : cc_request.operands()) { 2708 apply(operand); 2709 } 2710 break; 2711 } 2712 2713 case OpRequest::kHostComputeRequest: { 2714 const HostComputeRequest& hc_request = 2715 request.request().host_compute_request(); 2716 for (const ComputationDataHandle& operand : hc_request.operands()) { 2717 apply(operand); 2718 } 2719 break; 2720 } 2721 2722 case OpRequest::kDotRequest: { 2723 const DotRequest& dot_request = request.request().dot_request(); 2724 apply(dot_request.rhs()); 2725 apply(dot_request.lhs()); 2726 break; 2727 } 2728 2729 case OpRequest::kUnaryOpRequest: { 2730 const UnaryOpRequest& unary_op_request = 2731 request.request().unary_op_request(); 2732 apply(unary_op_request.operand()); 2733 break; 2734 } 2735 2736 case OpRequest::kBinaryOpRequest: { 2737 const BinaryOpRequest& binary_op_request = 2738 request.request().binary_op_request(); 2739 apply(binary_op_request.rhs()); 2740 apply(binary_op_request.lhs()); 2741 break; 2742 } 2743 2744 case OpRequest::kReducePrecisionRequest: { 2745 const ReducePrecisionRequest& reduce_precision_request = 2746 request.request().reduce_precision_request(); 2747 
apply(reduce_precision_request.operand()); 2748 break; 2749 } 2750 2751 case OpRequest::kTraceRequest: { 2752 const TraceRequest& trace_request = request.request().trace_request(); 2753 apply(trace_request.operand()); 2754 break; 2755 } 2756 2757 case OpRequest::kSendRequest: { 2758 const SendRequest& send_request = request.request().send_request(); 2759 apply(send_request.operand()); 2760 break; 2761 } 2762 2763 case OpRequest::kGatherRequest: { 2764 const GatherRequest& gather_request = request.request().gather_request(); 2765 apply(gather_request.input()); 2766 apply(gather_request.gather_indices()); 2767 break; 2768 } 2769 2770 case OpRequest::OP_NOT_SET: 2771 LOG(FATAL) << "OperationRequest doesn't contain a request"; 2772 2773 default: 2774 LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); 2775 } 2776 } 2777 2778 void ComputationLowerer::TraversePostorder( 2779 const ComputationDataHandle& root, 2780 std::unordered_map<int64, HloInstruction*>* visited, 2781 const std::function<void(const ComputationDataHandle&)>& visit) { 2782 // Stack containing {handle, enter} pairs. The 'enter' value describes whether 2783 // we are entering or leaving 'handle'. 2784 std::stack<std::pair<ComputationDataHandle, bool>> work; 2785 work.push({root, true}); 2786 while (!work.empty()) { 2787 ComputationDataHandle handle; 2788 bool enter; 2789 std::tie(handle, enter) = work.top(); 2790 work.pop(); 2791 2792 if (enter) { 2793 // We are entering 'handle'. The first time we enter 'handle', we add it 2794 // to 'visited' with a nullptr value. If 'handle' is already in 'visited', 2795 // we do not visit it again. This algorithm only uses the presence of 2796 // a handle in 'visited', but we use a map so we can use the same data 2797 // structure to store the HloInstruction outputs. 2798 if (visited->emplace(handle.handle(), nullptr).second) { 2799 const OperationRequest& request = 2800 session_computation_.requests().at(handle.handle()); 2801 // Push the corresponding 'leave' action onto the stack, followed by 2802 // the operands. 2803 work.push({handle, false}); 2804 ForEachOperand(request, [&work](const ComputationDataHandle& child) { 2805 work.push({child, true}); 2806 }); 2807 } 2808 } else { 2809 // We are leaving 'handle'. We have visited the operands of 'handle', and 2810 // now can visit the 'handle' itself. 2811 visit(handle); 2812 } 2813 } 2814 } 2815 2816 StatusOr<std::unique_ptr<HloComputation>> ComputationLowerer::Lower() { 2817 // Map from ComputationDataHandle to HLO instruction. Serves as a record of 2818 // which operations have been visited as well as a cache for looking up 2819 // ComputationDataHandles as HloInstructions. 2820 std::unordered_map<int64, HloInstruction*> instructions; 2821 2822 TF_ASSIGN_OR_RETURN(const OperationRequest* root_request, 2823 GetRoot(version_, session_computation_)); 2824 2825 auto visit = [&](const ComputationDataHandle& handle) { 2826 Visit(handle, &instructions); 2827 }; 2828 TraversePostorder(root_request->output_handle(), &instructions, visit); 2829 HloInstruction* hlo_root = 2830 instructions.at(root_request->output_handle().handle()); 2831 2832 if (include_unreachable_instructions_) { 2833 // Iterate through all computation data handles, and visit any unvisited 2834 // operations. 
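    // Handle values are assigned densely starting at 1, so iterating request
    // numbers 1..version_ covers every operation added at or before this
    // version.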
2835 for (int64 request_num = 1; request_num <= version_; ++request_num) { 2836 TF_ASSIGN_OR_RETURN(const OperationRequest* request, 2837 LookUpRequest(request_num, session_computation_)); 2838 TraversePostorder(request->output_handle(), &instructions, visit); 2839 } 2840 } 2841 2842 return hlo_builder_.Build(hlo_root); 2843 } 2844 2845 HloComputation* ComputationLowerer::ResolveComputation( 2846 const ComputationHandle& handle, 2847 VersionedComputationHandle::Version version) { 2848 const VersionedComputationHandle checked_handle = {handle, version}; 2849 return hlo_resolver_(checked_handle); 2850 } 2851 2852 HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( 2853 HloInstruction* operand, const Shape& output_shape) { 2854 auto fadd = [this](std::unique_ptr<HloInstruction> x) { 2855 return hlo_builder_.AddInstruction(std::move(x)); 2856 }; 2857 return fadd( 2858 HloInstruction::CreateBroadcastSequence(output_shape, operand, fadd)); 2859 } 2860 2861 void ComputationLowerer::Visit( 2862 const ComputationDataHandle& handle, 2863 std::unordered_map<int64, HloInstruction*>* instructions) { 2864 CHECK_LE(handle.handle(), version_); 2865 CHECK(instructions->at(handle.handle()) == nullptr); 2866 const OperationRequest& request = 2867 session_computation_.requests().at(handle.handle()); 2868 auto add_instruction = [&](std::unique_ptr<HloInstruction> instruction) { 2869 HloInstruction* hlo_instruction = 2870 hlo_builder_.AddInstruction(std::move(instruction)); 2871 hlo_instruction->set_metadata(request.request().metadata()); 2872 if (request.request().has_sharding()) { 2873 OpSharding op_sharding = request.request().sharding(); 2874 hlo_instruction->set_sharding( 2875 HloSharding::FromProto(op_sharding).ValueOrDie()); 2876 } 2877 return hlo_instruction; 2878 }; 2879 auto lookup_instruction = [&](const ComputationDataHandle& handle) { 2880 return instructions->at(handle.handle()); 2881 }; 2882 HloInstruction* hlo_instruction; 2883 switch (request.request().op_case()) { 2884 case OpRequest::kRngRequest: { 2885 const RngRequest& rng_request = request.request().rng_request(); 2886 std::vector<HloInstruction*> parameters; 2887 for (const ComputationDataHandle& param : rng_request.parameter()) { 2888 parameters.push_back(lookup_instruction(param)); 2889 } 2890 hlo_instruction = add_instruction(HloInstruction::CreateRng( 2891 request.output_shape(), rng_request.distribution(), parameters)); 2892 break; 2893 } 2894 2895 case OpRequest::kConstantRequest: { 2896 const ConstantRequest& constant_request = 2897 request.request().constant_request(); 2898 hlo_instruction = add_instruction(HloInstruction::CreateConstant( 2899 Literal::CreateFromProto(constant_request.literal()) 2900 .ConsumeValueOrDie())); 2901 break; 2902 } 2903 2904 case OpRequest::kGetTupleElementRequest: { 2905 const GetTupleElementRequest& get_tuple_element_request = 2906 request.request().get_tuple_element_request(); 2907 HloInstruction* operand = 2908 lookup_instruction(get_tuple_element_request.operand()); 2909 hlo_instruction = add_instruction(HloInstruction::CreateGetTupleElement( 2910 request.output_shape(), operand, get_tuple_element_request.index())); 2911 break; 2912 } 2913 2914 case OpRequest::kSliceRequest: { 2915 const SliceRequest& slice_request = request.request().slice_request(); 2916 HloInstruction* operand = lookup_instruction(slice_request.operand()); 2917 hlo_instruction = add_instruction(HloInstruction::CreateSlice( 2918 request.output_shape(), operand, 2919 
AsInt64Slice(slice_request.start_indices()), 2920 AsInt64Slice(slice_request.limit_indices()), 2921 AsInt64Slice(slice_request.strides()))); 2922 break; 2923 } 2924 2925 case OpRequest::kDynamicSliceRequest: { 2926 const DynamicSliceRequest& dynamic_slice_request = 2927 request.request().dynamic_slice_request(); 2928 HloInstruction* operand = 2929 lookup_instruction(dynamic_slice_request.operand()); 2930 HloInstruction* start_indices = 2931 lookup_instruction(dynamic_slice_request.start_indices()); 2932 2933 hlo_instruction = add_instruction(HloInstruction::CreateDynamicSlice( 2934 request.output_shape(), operand, start_indices, 2935 AsInt64Slice(dynamic_slice_request.slice_sizes()))); 2936 break; 2937 } 2938 2939 case OpRequest::kDynamicUpdateSliceRequest: { 2940 const DynamicUpdateSliceRequest& dynamic_update_slice_request = 2941 request.request().dynamic_update_slice_request(); 2942 HloInstruction* operand = 2943 lookup_instruction(dynamic_update_slice_request.operand()); 2944 HloInstruction* update = 2945 lookup_instruction(dynamic_update_slice_request.update()); 2946 HloInstruction* start_indices = 2947 lookup_instruction(dynamic_update_slice_request.start_indices()); 2948 hlo_instruction = 2949 add_instruction(HloInstruction::CreateDynamicUpdateSlice( 2950 request.output_shape(), operand, update, start_indices)); 2951 break; 2952 } 2953 2954 case OpRequest::kConcatenateRequest: { 2955 const ConcatenateRequest& concatenate_request = 2956 request.request().concatenate_request(); 2957 std::vector<HloInstruction*> operands; 2958 for (const ComputationDataHandle& handle : 2959 concatenate_request.operands()) { 2960 HloInstruction* operand = lookup_instruction(handle); 2961 operands.push_back(operand); 2962 } 2963 hlo_instruction = add_instruction(HloInstruction::CreateConcatenate( 2964 request.output_shape(), operands, concatenate_request.dimension())); 2965 break; 2966 } 2967 2968 case OpRequest::kConvolveRequest: { 2969 const ConvolveRequest& convolve_request = 2970 request.request().convolve_request(); 2971 HloInstruction* lhs = lookup_instruction(convolve_request.lhs()); 2972 HloInstruction* rhs = lookup_instruction(convolve_request.rhs()); 2973 hlo_instruction = add_instruction(HloInstruction::CreateConvolve( 2974 request.output_shape(), lhs, rhs, convolve_request.window(), 2975 convolve_request.dimension_numbers())); 2976 break; 2977 } 2978 2979 case OpRequest::kFftRequest: { 2980 const FftRequest& fft_request = request.request().fft_request(); 2981 HloInstruction* operand = lookup_instruction(fft_request.operand()); 2982 hlo_instruction = add_instruction(HloInstruction::CreateFft( 2983 request.output_shape(), operand, fft_request.fft_type(), 2984 AsInt64Slice(fft_request.fft_length()))); 2985 break; 2986 } 2987 2988 case OpRequest::kDotRequest: { 2989 const DotRequest& dot_request = request.request().dot_request(); 2990 HloInstruction* lhs = lookup_instruction(dot_request.lhs()); 2991 HloInstruction* rhs = lookup_instruction(dot_request.rhs()); 2992 hlo_instruction = add_instruction(HloInstruction::CreateDot( 2993 request.output_shape(), lhs, rhs, dot_request.dimension_numbers())); 2994 break; 2995 } 2996 2997 case OpRequest::kCrossReplicaSumRequest: { 2998 const CrossReplicaSumRequest& cross_replica_sum_request = 2999 request.request().cross_replica_sum_request(); 3000 HloInstruction* operand = 3001 lookup_instruction(cross_replica_sum_request.operand()); 3002 hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( 3003 request.output_shape(), {operand})); 3004 
break; 3005 } 3006 3007 case OpRequest::kInfeedRequest: { 3008 const InfeedRequest& infeed_request = request.request().infeed_request(); 3009 hlo_instruction = add_instruction(HloInstruction::CreateInfeed( 3010 request.output_shape(), infeed_request.config())); 3011 break; 3012 } 3013 3014 case OpRequest::kOutfeedRequest: { 3015 const OutfeedRequest& outfeed_request = 3016 request.request().outfeed_request(); 3017 HloInstruction* operand = lookup_instruction(outfeed_request.operand()); 3018 hlo_instruction = add_instruction(HloInstruction::CreateOutfeed( 3019 outfeed_request.shape(), operand, outfeed_request.outfeed_config())); 3020 break; 3021 } 3022 3023 case OpRequest::kMapRequest: { 3024 const MapRequest& map_request = request.request().map_request(); 3025 std::vector<HloInstruction*> operands; 3026 for (const ComputationDataHandle& handle : map_request.operands()) { 3027 HloInstruction* operand = lookup_instruction(handle); 3028 operands.push_back(operand); 3029 } 3030 CHECK_EQ(1, request.embedded_computation_versions_size()); 3031 VersionedComputationHandle::Version map_version = 3032 request.embedded_computation_versions(0); 3033 HloComputation* map_computation = 3034 ResolveComputation(map_request.to_apply(), map_version); 3035 hlo_instruction = add_instruction(HloInstruction::CreateMap( 3036 request.output_shape(), operands, map_computation)); 3037 break; 3038 } 3039 3040 case OpRequest::kReduceRequest: { 3041 const ReduceRequest& reduce_request = request.request().reduce_request(); 3042 HloInstruction* operand = lookup_instruction(reduce_request.operand()); 3043 HloInstruction* init_value = 3044 lookup_instruction(reduce_request.init_value()); 3045 CHECK_EQ(1, request.embedded_computation_versions_size()); 3046 VersionedComputationHandle::Version reduce_version = 3047 request.embedded_computation_versions(0); 3048 HloComputation* reduce_computation = 3049 ResolveComputation(reduce_request.to_apply(), reduce_version); 3050 hlo_instruction = add_instruction(HloInstruction::CreateReduce( 3051 request.output_shape(), operand, init_value, 3052 AsInt64Slice(reduce_request.dimensions()), reduce_computation)); 3053 break; 3054 } 3055 3056 case OpRequest::kReduceWindowRequest: { 3057 const ReduceWindowRequest& reduce_window_request = 3058 request.request().reduce_window_request(); 3059 HloInstruction* operand = 3060 lookup_instruction(reduce_window_request.operand()); 3061 HloInstruction* init_value = 3062 lookup_instruction(reduce_window_request.init_value()); 3063 CHECK_EQ(1, request.embedded_computation_versions_size()); 3064 VersionedComputationHandle::Version reduce_window_version = 3065 request.embedded_computation_versions(0); 3066 HloComputation* reduce_window_computation = ResolveComputation( 3067 reduce_window_request.to_apply(), reduce_window_version); 3068 hlo_instruction = add_instruction(HloInstruction::CreateReduceWindow( 3069 request.output_shape(), operand, init_value, 3070 reduce_window_request.window(), reduce_window_computation)); 3071 break; 3072 } 3073 3074 case OpRequest::kSelectAndScatterRequest: { 3075 const SelectAndScatterRequest& select_and_scatter_request = 3076 request.request().select_and_scatter_request(); 3077 HloInstruction* operand = 3078 lookup_instruction(select_and_scatter_request.operand()); 3079 HloInstruction* source = 3080 lookup_instruction(select_and_scatter_request.source()); 3081 HloInstruction* init_value = 3082 lookup_instruction(select_and_scatter_request.init_value()); 3083 CHECK_EQ(2, request.embedded_computation_versions_size()); 3084 
      VersionedComputationHandle::Version select_version =
          request.embedded_computation_versions(0);
      VersionedComputationHandle::Version scatter_version =
          request.embedded_computation_versions(1);
      HloComputation* select_computation = ResolveComputation(
          select_and_scatter_request.select(), select_version);
      HloComputation* scatter_computation = ResolveComputation(
          select_and_scatter_request.scatter(), scatter_version);
      hlo_instruction = add_instruction(HloInstruction::CreateSelectAndScatter(
          request.output_shape(), operand, select_computation,
          select_and_scatter_request.window(), source, init_value,
          scatter_computation));
      break;
    }

    case OpRequest::kBatchNormTrainingRequest: {
      const BatchNormTrainingRequest& batch_norm_training_request =
          request.request().batch_norm_training_request();
      HloInstruction* operand =
          lookup_instruction(batch_norm_training_request.operand());
      HloInstruction* scale =
          lookup_instruction(batch_norm_training_request.scale());
      HloInstruction* offset =
          lookup_instruction(batch_norm_training_request.offset());

      hlo_instruction = add_instruction(HloInstruction::CreateBatchNormTraining(
          request.output_shape(), operand, scale, offset,
          batch_norm_training_request.epsilon(),
          batch_norm_training_request.feature_index()));
      break;
    }

    case OpRequest::kBatchNormInferenceRequest: {
      const BatchNormInferenceRequest& batch_norm_inference_request =
          request.request().batch_norm_inference_request();
      HloInstruction* operand =
          lookup_instruction(batch_norm_inference_request.operand());
      HloInstruction* scale =
          lookup_instruction(batch_norm_inference_request.scale());
      HloInstruction* offset =
          lookup_instruction(batch_norm_inference_request.offset());
      HloInstruction* mean =
          lookup_instruction(batch_norm_inference_request.mean());
      HloInstruction* variance =
          lookup_instruction(batch_norm_inference_request.variance());

      hlo_instruction =
          add_instruction(HloInstruction::CreateBatchNormInference(
              request.output_shape(), operand, scale, offset, mean, variance,
              batch_norm_inference_request.epsilon(),
              batch_norm_inference_request.feature_index()));
      break;
    }

    case OpRequest::kBatchNormGradRequest: {
      const BatchNormGradRequest& batch_norm_grad_request =
          request.request().batch_norm_grad_request();

      HloInstruction* operand =
          lookup_instruction(batch_norm_grad_request.operand());
      HloInstruction* scale =
          lookup_instruction(batch_norm_grad_request.scale());
      HloInstruction* mean = lookup_instruction(batch_norm_grad_request.mean());
      HloInstruction* variance =
          lookup_instruction(batch_norm_grad_request.variance());
      HloInstruction* grad_output =
          lookup_instruction(batch_norm_grad_request.grad_output());

      hlo_instruction = add_instruction(HloInstruction::CreateBatchNormGrad(
          request.output_shape(), operand, scale, mean, variance, grad_output,
          batch_norm_grad_request.epsilon(),
          batch_norm_grad_request.feature_index()));
      break;
    }

    case OpRequest::kBroadcastRequest: {
      const BroadcastRequest& broadcast_request =
          request.request().broadcast_request();
      HloInstruction* operand = lookup_instruction(broadcast_request.operand());
      std::vector<int64> broadcast_dimensions;
      // The client-level broadcast instruction just appends dimensions on the
      // left (adds lowest numbered dimensions). The HLO broadcast op is more
      // flexible and can add new dimensions anywhere. The broadcast_dimensions
      // maps operand dimensions to dimensions in the broadcast output, so
      // to append dimensions on the left the broadcast_dimensions should just
      // be the n highest dimension numbers of the output shape where n is
      // the number of input dimensions.
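      // Worked example (illustrative): broadcasting an f32[3,4] operand into
      // an f32[7,8,3,4] output yields broadcast_dimensions = {2, 3}, i.e.
      // operand dimension 0 maps to output dimension 2 and 1 maps to 3.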
      broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape()));
      for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) {
        broadcast_dimensions.push_back(i +
                                       ShapeUtil::Rank(request.output_shape()) -
                                       ShapeUtil::Rank(operand->shape()));
      }
      hlo_instruction = add_instruction(HloInstruction::CreateBroadcast(
          request.output_shape(), operand, broadcast_dimensions));
      break;
    }

    case OpRequest::kReshapeRequest: {
      const ReshapeRequest& reshape_request =
          request.request().reshape_request();
      HloInstruction* operand = lookup_instruction(reshape_request.operand());
      HloInstruction* transposed;
      if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) {
        transposed = operand;
      } else {
        transposed = add_instruction(HloInstruction::CreateTranspose(
            ShapeUtil::PermuteDimensions(
                InversePermutation(AsInt64Slice(reshape_request.dimensions())),
                operand->shape()),
            operand, AsInt64Slice(reshape_request.dimensions())));
      }
      hlo_instruction = add_instruction(
          HloInstruction::CreateReshape(request.output_shape(), transposed));
      break;
    }

    case OpRequest::kTransposeRequest: {
      const TransposeRequest& transpose_request =
          request.request().transpose_request();
      HloInstruction* operand = lookup_instruction(transpose_request.operand());
      hlo_instruction = add_instruction(HloInstruction::CreateTranspose(
          ShapeUtil::PermuteDimensions(
              InversePermutation(AsInt64Slice(transpose_request.dimensions())),
              operand->shape()),
          operand, AsInt64Slice(transpose_request.dimensions())));
      break;
    }

    case OpRequest::kReverseRequest: {
      const ReverseRequest& reverse_request =
          request.request().reverse_request();
      HloInstruction* operand = lookup_instruction(reverse_request.operand());
      hlo_instruction = add_instruction(HloInstruction::CreateReverse(
          request.output_shape(), operand,
          AsInt64Slice(reverse_request.dimensions())));
      break;
    }

    case OpRequest::kPadRequest: {
      const PadRequest& pad_request = request.request().pad_request();
      HloInstruction* operand = lookup_instruction(pad_request.operand());
      HloInstruction* padding_value =
          lookup_instruction(pad_request.padding_value());
      hlo_instruction = add_instruction(HloInstruction::CreatePad(
          request.output_shape(), operand, padding_value,
          pad_request.padding_config()));
      break;
    }

    case OpRequest::kRecvRequest: {
      const RecvRequest& recv_request = request.request().recv_request();
      HloInstruction* recv = add_instruction(HloInstruction::CreateRecv(
          request.output_shape(), recv_request.channel_handle().handle()));
      hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv));
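      // Note: the client-level Recv lowers to a kRecv/kRecvDone pair; kRecv
      // initiates the transfer on the channel and kRecvDone waits for it and
      // produces the received data.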
      break;
    }

    case OpRequest::kParameterRequest: {
      const ParameterRequest& parameter_request =
          request.request().parameter_request();
      hlo_instruction = add_instruction(HloInstruction::CreateParameter(
          parameter_request.parameter(), request.output_shape(),
          parameter_request.name()));
      break;
    }

    case OpRequest::kConvertRequest: {
      const ConvertRequest& convert_request =
          request.request().convert_request();
      HloInstruction* operand = lookup_instruction(convert_request.operand());
      hlo_instruction = add_instruction(
          HloInstruction::CreateConvert(request.output_shape(), operand));
      break;
    }

    case OpRequest::kBitcastConvertRequest: {
      const ConvertRequest& convert_request =
          request.request().bitcast_convert_request();
      HloInstruction* operand = lookup_instruction(convert_request.operand());
      hlo_instruction = add_instruction(HloInstruction::CreateBitcastConvert(
          request.output_shape(), operand));
      break;
    }

    case OpRequest::kWhileRequest: {
      const WhileRequest& while_request = request.request().while_request();
      CHECK_EQ(2, request.embedded_computation_versions_size());
      VersionedComputationHandle::Version condition_version =
          request.embedded_computation_versions(0);
      HloComputation* condition =
          ResolveComputation(while_request.condition(), condition_version);
      VersionedComputationHandle::Version body_version =
          request.embedded_computation_versions(1);
      HloComputation* body =
          ResolveComputation(while_request.body(), body_version);
      HloInstruction* init = lookup_instruction(while_request.init());
      hlo_instruction = add_instruction(HloInstruction::CreateWhile(
          request.output_shape(), condition, body, init));
      break;
    }

    case OpRequest::kConditionalRequest: {
      const ConditionalRequest& conditional_request =
          request.request().conditional_request();
      CHECK_EQ(2, request.embedded_computation_versions_size());
      VersionedComputationHandle::Version true_computation_version =
          request.embedded_computation_versions(0);
      HloComputation* true_computation = ResolveComputation(
          conditional_request.true_computation(), true_computation_version);
      VersionedComputationHandle::Version false_computation_version =
          request.embedded_computation_versions(1);
      HloComputation* false_computation = ResolveComputation(
          conditional_request.false_computation(), false_computation_version);
      HloInstruction* predicate =
          lookup_instruction(conditional_request.predicate());
      HloInstruction* true_operand =
          lookup_instruction(conditional_request.true_operand());
      HloInstruction* false_operand =
          lookup_instruction(conditional_request.false_operand());
      hlo_instruction = add_instruction(HloInstruction::CreateConditional(
          request.output_shape(), predicate, true_operand, true_computation,
          false_operand, false_computation));
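      // Illustrative semantics: at runtime the conditional computes
      // true_computation(true_operand) when predicate is true and
      // false_computation(false_operand) otherwise.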
      break;
    }

    case OpRequest::kTernaryOpRequest: {
      const TernaryOpRequest& ternary_op_request =
          request.request().ternary_op_request();
      HloInstruction* lhs = lookup_instruction(ternary_op_request.lhs());
      HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs());
      HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs());
      auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop());

      if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) {
        if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
          // lhs side is being implicitly broadcast. Change to explicit.
          lhs =
              ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape());
        }

        if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) {
          rhs =
              ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape());
        }

        if (!ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) {
          ehs =
              ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape());
        }
      }

      hlo_instruction = add_instruction(HloInstruction::CreateTernary(
          request.output_shape(), hlo_opcode, lhs, rhs, ehs));
      break;
    }

    case OpRequest::kVariadicOpRequest: {
      const VariadicOpRequest& variadic_op_request =
          request.request().variadic_op_request();
      std::vector<HloInstruction*> operands;
      for (const ComputationDataHandle& handle :
           variadic_op_request.operands()) {
        HloInstruction* operand = lookup_instruction(handle);
        operands.push_back(operand);
      }
      auto hlo_opcode =
          VariadicOperationToHloOpcode(variadic_op_request.varop());
      hlo_instruction = add_instruction(HloInstruction::CreateVariadic(
          request.output_shape(), hlo_opcode, operands));
      break;
    }

    case OpRequest::kCallRequest: {
      const CallRequest& call_request = request.request().call_request();
      std::vector<HloInstruction*> operands;
      for (const ComputationDataHandle& handle : call_request.operands()) {
        operands.push_back(lookup_instruction(handle));
      }
      CHECK_EQ(1, request.embedded_computation_versions_size());
      VersionedComputationHandle::Version call_version =
          request.embedded_computation_versions(0);
      HloComputation* call_computation =
          ResolveComputation(call_request.to_apply(), call_version);
      hlo_instruction = add_instruction(HloInstruction::CreateCall(
          request.output_shape(), operands, call_computation));
      break;
    }

    case OpRequest::kCustomCallRequest: {
      const CustomCallRequest& cc_request =
          request.request().custom_call_request();
      std::vector<HloInstruction*> operands;
      for (const ComputationDataHandle& operand : cc_request.operands()) {
        operands.push_back(lookup_instruction(operand));
      }
      hlo_instruction = add_instruction(HloInstruction::CreateCustomCall(
          cc_request.shape(), operands, cc_request.call_target_name()));
      break;
    }

    case OpRequest::kHostComputeRequest: {
      const HostComputeRequest& host_compute_request =
          request.request().host_compute_request();
      std::vector<HloInstruction*> operands;
      for (const ComputationDataHandle& operand :
           host_compute_request.operands()) {
        operands.push_back(lookup_instruction(operand));
      }
      auto output_shape = host_compute_request.shape();
      auto channel_name = host_compute_request.channel_name();
      auto cost_estimate_ns = host_compute_request.cost_estimate_ns();
      hlo_instruction = add_instruction(HloInstruction::CreateHostCompute(
          output_shape, operands, channel_name, cost_estimate_ns));
      break;
    }

    case OpRequest::kUnaryOpRequest: {
      const UnaryOpRequest& unary_op_request =
          request.request().unary_op_request();
      HloInstruction* operand = lookup_instruction(unary_op_request.operand());
      auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop());
      hlo_instruction = add_instruction(HloInstruction::CreateUnary(
          request.output_shape(), hlo_opcode, operand));
      break;
    }
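    // The next case also materializes client-level "broadcast in dimension"
    // semantics. Illustrative example: adding an f32[3] vector to an f32[2,3]
    // matrix with broadcast_dimensions = {1} first broadcasts the vector to
    // f32[2,3], then emits a rank-matched kAdd.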
    case OpRequest::kBinaryOpRequest: {
      const BinaryOpRequest& binary_op_request =
          request.request().binary_op_request();
      HloInstruction* lhs = lookup_instruction(binary_op_request.lhs());
      HloInstruction* rhs = lookup_instruction(binary_op_request.rhs());
      auto hlo_opcode = BinaryOperationToHloOpcode(binary_op_request.binop());
      if (binary_op_request.broadcast_dimensions_size() > 0 &&
          ShapeUtil::Rank(lhs->shape()) != ShapeUtil::Rank(rhs->shape())) {
        // Emit a broadcast instruction to perform the "broadcast in dimension"
        // operation.
        HloInstruction* operand_to_broadcast =
            ShapeUtil::Rank(lhs->shape()) < ShapeUtil::Rank(rhs->shape()) ? lhs
                                                                          : rhs;
        CHECK_EQ(ShapeUtil::Rank(operand_to_broadcast->shape()),
                 binary_op_request.broadcast_dimensions().size());

        // Construct the bounds of the shape of the kBroadcast instruction
        // responsible for the in-dimension broadcast.
        std::vector<int64> output_dimensions;
        for (int64 size : request.output_shape().dimensions()) {
          output_dimensions.push_back(size);
        }
        for (int64 operand_dim = 0;
             operand_dim < ShapeUtil::Rank(operand_to_broadcast->shape());
             ++operand_dim) {
          int64 output_dim =
              binary_op_request.broadcast_dimensions()[operand_dim];
          output_dimensions[output_dim] =
              operand_to_broadcast->shape().dimensions(operand_dim);
        }

        Shape broadcast_shape = ShapeUtil::MakeShape(
            operand_to_broadcast->shape().element_type(), output_dimensions);

        // The broadcast semantics of a client-level binary op are identical
        // to the HLO broadcast semantics, so the broadcast_dimensions field
        // can be passed to the instruction builder unchanged.
        HloInstruction* broadcasted_operand =
            add_instruction(HloInstruction::CreateBroadcast(
                broadcast_shape, operand_to_broadcast,
                AsInt64Slice(binary_op_request.broadcast_dimensions())));

        lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs;
        rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs;
      }
      if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) {
        if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
          // lhs side is being implicitly broadcast. Change to explicit.
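          // (Implicit broadcast here means equal rank with degenerate size-1
          // dimensions, e.g. f32[3,1] op f32[3,4]; the helper inserts an
          // explicit kBroadcast so later passes need not special-case it.)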
          lhs =
              ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape());
        }

        if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) {
          rhs =
              ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape());
        }
      }
      hlo_instruction = add_instruction(HloInstruction::CreateBinary(
          request.output_shape(), hlo_opcode, lhs, rhs));
      break;
    }

    case OpRequest::kReducePrecisionRequest: {
      const ReducePrecisionRequest& reduce_precision_request =
          request.request().reduce_precision_request();
      HloInstruction* operand =
          lookup_instruction(reduce_precision_request.operand());
      auto exponent_bits = reduce_precision_request.exponent_bits();
      auto mantissa_bits = reduce_precision_request.mantissa_bits();
      hlo_instruction = add_instruction(HloInstruction::CreateReducePrecision(
          request.output_shape(), operand, exponent_bits, mantissa_bits));
      break;
    }

    case OpRequest::kTraceRequest: {
      const TraceRequest& trace_request = request.request().trace_request();
      HloInstruction* operand = lookup_instruction(trace_request.operand());
      hlo_instruction = add_instruction(
          HloInstruction::CreateTrace(trace_request.tag(), operand));
      operand->set_tracing(hlo_instruction);
      break;
    }

    case OpRequest::kSendRequest: {
      const SendRequest& send_request = request.request().send_request();
      HloInstruction* operand = lookup_instruction(send_request.operand());
      HloInstruction* send = add_instruction(HloInstruction::CreateSend(
          operand, send_request.channel_handle().handle()));
      hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send));
      break;
    }

    case OpRequest::kGatherRequest: {
      const GatherRequest& gather_request = request.request().gather_request();
      HloInstruction* input_operand =
          lookup_instruction(gather_request.input());
      HloInstruction* gather_indices_operand =
          lookup_instruction(gather_request.gather_indices());
      std::vector<int64> window_bounds;
      c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds));
      hlo_instruction = add_instruction(HloInstruction::CreateGather(
          request.output_shape(), input_operand, gather_indices_operand,
          gather_request.dimension_numbers(), window_bounds));
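      // window_bounds[i] is the extent sliced from dimension i of the input
      // for each gathered index; illustratively, an f32[5,6] input with
      // window_bounds {1, 6} gathers whole rows.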
      break;
    }

    case OpRequest::OP_NOT_SET:
      LOG(FATAL) << "OperationRequest doesn't contain a request";

    default:
      LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
  }
  (*instructions)[handle.handle()] = hlo_instruction;
}  // NOLINT(readability/fn_size)

}  // namespace

StatusOr<std::unique_ptr<HloComputation>> UserComputation::BuildHloComputation(
    VersionedComputationHandle::Version version,
    HloComputationResolver hlo_resolver, const DebugOptions& debug_options,
    bool include_unreachable_instructions) const {
  tensorflow::mutex_lock lock(mutex_);

  VLOG(2) << "Building HloComputation from UserComputation " << name_
          << " at version " << version;
  XLA_VLOG_LINES(3, session_computation_.DebugString());

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloComputation> hlo_computation,
      ComputationLowerer::Lower(
          tensorflow::strings::StrCat(name(), ".v", version),
          session_computation_, version, std::move(hlo_resolver), debug_options,
          include_unreachable_instructions));

  return std::move(hlo_computation);
}

}  // namespace xla