/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <memory>
#include <numeric>
#include <set>
#include <vector>

#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatmap.h"

namespace xla {

Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
  return CheckUnaryShape(hlo);
}

Status ShapeVerifier::HandleElementwiseBinary(HloInstruction* hlo) {
  return CheckBinaryShape(hlo);
}

Status ShapeVerifier::HandleClamp(HloInstruction* clamp) {
  return CheckTernaryShape(clamp);
}

Status ShapeVerifier::HandleSelect(HloInstruction* select) {
  return CheckTernaryShape(select);
}

Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) {
  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : concatenate->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(concatenate,
                    ShapeInference::InferConcatOpShape(
                        operand_shapes, concatenate->concatenate_dimension()));
}

Status ShapeVerifier::HandleConvert(HloInstruction* convert) {
  return CheckShape(convert, ShapeInference::InferConvertShape(
                                 convert->operand(0)->shape(),
                                 convert->shape().element_type()));
}

Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) {
  return CheckShape(convert, ShapeInference::InferBitcastConvertShape(
                                 convert->operand(0)->shape(),
                                 convert->shape().element_type()));
}

Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
  return CheckUnaryShape(copy);
}

Status ShapeVerifier::HandleDot(HloInstruction* dot) {
  TF_ASSIGN_OR_RETURN(const Shape expected,
                      ShapeInference::InferDotOpShape(
                          dot->operand(0)->shape(), dot->operand(1)->shape(),
                          dot->dot_dimension_numbers()));
  return CheckShape(dot, expected);
}

Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
  TF_ASSIGN_OR_RETURN(
      const Shape expected,
      ShapeInference::InferConvolveShape(
          convolution->operand(0)->shape(), convolution->operand(1)->shape(),
          convolution->window(),
          convolution->convolution_dimension_numbers()));
  return CheckShape(convolution, expected);
}

Status ShapeVerifier::HandleFft(HloInstruction* fft) {
  TF_ASSIGN_OR_RETURN(
      const Shape expected,
      ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(),
                                    fft->fft_length()));
  return CheckShape(fft, expected);
}

Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) {
  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : crs->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(crs,
                    ShapeInference::InferCrossReplicaSumShape(operand_shapes));
}
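
// ReducePrecision's result shape is inferred from its operand shape together
// with the requested exponent and mantissa bit counts.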
Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
  return CheckShape(reduce_precision,
                    ShapeInference::InferReducePrecisionShape(
                        reduce_precision->operand(0)->shape(),
                        reduce_precision->exponent_bits(),
                        reduce_precision->mantissa_bits()));
}

Status ShapeVerifier::HandleInfeed(HloInstruction*) {
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) {
  // Outfeed has a separate shape field for the value which is outfed to the
  // host. The shape of the instruction itself is always nil because the
  // outfeed produces no HLO value in the graph.
  if (!ShapeUtil::Compatible(outfeed->outfeed_shape(),
                             outfeed->operand(0)->shape())) {
    return InvalidArgument(
        "Expected outfeed to have shape compatible with operand's shape %s, "
        "actual shape is %s:\n%s",
        ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(),
        ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(),
        outfeed->ToString().c_str());
  }
  return CheckShape(outfeed, ShapeUtil::MakeNil());
}

Status ShapeVerifier::HandleHostCompute(HloInstruction*) {
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleRng(HloInstruction*) {
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
  return CheckShape(
      reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(),
                                                 reverse->dimensions()));
}

Status ShapeVerifier::HandleSort(HloInstruction* sort) {
  return CheckUnaryShape(sort);
}

Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
  return CheckShape(constant, constant->literal().shape());
}

Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  return CheckShape(get_tuple_element,
                    ShapeInference::InferGetTupleElementShape(
                        get_tuple_element->operand(0)->shape(),
                        get_tuple_element->tuple_index()));
}

Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
  return CheckShape(
      reduce,
      ShapeInference::InferReduceShape(
          reduce->operand(0)->shape(), reduce->operand(1)->shape(),
          reduce->dimensions(), reduce->to_apply()->ComputeProgramShape()));
}

Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
  // HLO broadcast has no exact analog at the proto level so there is no
  // ShapeInference method. Check the output shape explicitly.
  const Shape& operand_shape = broadcast->operand(0)->shape();
  // Check for mixed precision.
  TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape()));
  TF_RET_CHECK(ShapeUtil::Rank(operand_shape) ==
               broadcast->dimensions().size());
  for (int64 operand_dimension = 0;
       operand_dimension < ShapeUtil::Rank(operand_shape);
       ++operand_dimension) {
    int64 output_dimension = broadcast->dimensions()[operand_dimension];
    TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) ==
                 operand_shape.dimensions(operand_dimension))
        << broadcast->ToString() << " operand shape " << operand_shape;
  }
  return tensorflow::Status::OK();
}
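
// A reshape may rearrange its operand's dimensions arbitrarily, but it must
// preserve the total element count.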
Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
  // Check for mixed precision.
  TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape()));
  TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
               ShapeUtil::ElementsIn(reshape->operand(0)->shape()));
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
  return CheckShape(
      transpose, ShapeInference::InferTransposeShape(
                     transpose->operand(0)->shape(), transpose->dimensions()));
}

Status ShapeVerifier::HandleParameter(HloInstruction*) {
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleFusion(HloInstruction*) {
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleCall(HloInstruction* call) {
  // The shape of kCall should match the shape of the computation it calls.
  return CheckShape(call, call->to_apply()->ComputeProgramShape().result());
}

Status ShapeVerifier::HandleCustomCall(HloInstruction*) {
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
  return CheckShape(slice,
                    ShapeInference::InferSliceShape(
                        slice->operand(0)->shape(), slice->slice_starts(),
                        slice->slice_limits(), slice->slice_strides()));
}

Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) {
  return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape(
                                       dynamic_slice->operand(0)->shape(),
                                       dynamic_slice->operand(1)->shape(),
                                       dynamic_slice->dynamic_slice_sizes()));
}

Status ShapeVerifier::HandleDynamicUpdateSlice(
    HloInstruction* dynamic_update_slice) {
  return CheckShape(dynamic_update_slice,
                    ShapeInference::InferDynamicUpdateSliceShape(
                        dynamic_update_slice->operand(0)->shape(),
                        dynamic_update_slice->operand(1)->shape(),
                        dynamic_update_slice->operand(2)->shape()));
}

Status ShapeVerifier::HandleTuple(HloInstruction* tuple) {
  return CheckVariadicShape(tuple);
}
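
// Map is currently restricted to mapping over all dimensions of its operands,
// so the map dimensions are exactly [0, max_operand_rank).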
Status ShapeVerifier::HandleMap(HloInstruction* map) {
  std::vector<const Shape*> operand_shapes;
  int64 max_operand_rank = 0;
  for (const HloInstruction* operand : map->operands()) {
    operand_shapes.push_back(&operand->shape());
    max_operand_rank =
        std::max(max_operand_rank, ShapeUtil::Rank(operand->shape()));
  }
  // TODO(b/65689298) Remove code below once Map is generalized to accept
  // arbitrary map dimensions.
  std::vector<int64> map_dims(max_operand_rank);
  std::iota(map_dims.begin(), map_dims.end(), 0);
  return CheckShape(map, ShapeInference::InferMapShape(
                             operand_shapes,
                             map->to_apply()->ComputeProgramShape(), map_dims));
}

Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
  return CheckShape(
      reduce_window,
      ShapeInference::InferReduceWindowShape(
          reduce_window->operand(0)->shape(),
          reduce_window->operand(1)->shape(), reduce_window->window(),
          reduce_window->to_apply()->ComputeProgramShape()));
}

Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
  return CheckShape(
      instruction,
      ShapeInference::InferSelectAndScatterShape(
          instruction->operand(0)->shape(),
          instruction->select()->ComputeProgramShape(), instruction->window(),
          instruction->operand(1)->shape(), instruction->operand(2)->shape(),
          instruction->scatter()->ComputeProgramShape()));
}

Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
  // The shape of kWhile should match the shape of the body computation it
  // calls.
  return CheckShape(xla_while,
                    xla_while->while_body()->ComputeProgramShape().result());
}

Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
  TF_RETURN_IF_ERROR(CheckShape(
      conditional,
      conditional->true_computation()->ComputeProgramShape().result()));
  return CheckShape(
      conditional,
      conditional->false_computation()->ComputeProgramShape().result());
}

Status ShapeVerifier::HandlePad(HloInstruction* pad) {
  return CheckShape(pad, ShapeInference::InferPadShape(
                             pad->operand(0)->shape(),
                             pad->operand(1)->shape(), pad->padding_config()));
}

Status ShapeVerifier::HandleSend(HloInstruction* send) {
  TF_RET_CHECK(send->users().size() == 1);
  const HloInstruction* send_done = send->users().front();
  TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
  TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
  return CheckShape(
      send, ShapeUtil::MakeTupleShape(
                {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})}));
}

Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
  TF_RET_CHECK(send_done->operands().size() == 1);
  const HloInstruction* send = send_done->operand(0);
  TF_RET_CHECK(send->opcode() == HloOpcode::kSend);
  TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
  return CheckShape(send_done, ShapeUtil::MakeNil());
}

Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
  TF_RET_CHECK(recv->users().size() == 1);
  const HloInstruction* recv_done = recv->users().front();
  TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
  TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
  return CheckShape(recv,
                    ShapeUtil::MakeTupleShape(
                        {recv_done->shape(), ShapeUtil::MakeShape(U32, {})}));
}

Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
  TF_RET_CHECK(recv_done->operands().size() == 1);
  const HloInstruction* recv = recv_done->operand(0);
  TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv);
  TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
  return CheckShape(recv_done, recv->shape().tuple_shapes(0));
}
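
// The three batch-norm handlers below defer entirely to shape inference over
// their operand shapes and feature index.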
Status ShapeVerifier::HandleBatchNormTraining(
    HloInstruction* batch_norm_training) {
  return CheckShape(batch_norm_training,
                    ShapeInference::InferBatchNormTrainingShape(
                        batch_norm_training->operand(0)->shape(),
                        batch_norm_training->operand(1)->shape(),
                        batch_norm_training->operand(2)->shape(),
                        batch_norm_training->feature_index()));
}

Status ShapeVerifier::HandleBatchNormInference(
    HloInstruction* batch_norm_inference) {
  return CheckShape(batch_norm_inference,
                    ShapeInference::InferBatchNormInferenceShape(
                        batch_norm_inference->operand(0)->shape(),
                        batch_norm_inference->operand(1)->shape(),
                        batch_norm_inference->operand(2)->shape(),
                        batch_norm_inference->operand(3)->shape(),
                        batch_norm_inference->operand(4)->shape(),
                        batch_norm_inference->feature_index()));
}

Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
  return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape(
                                         batch_norm_grad->operand(0)->shape(),
                                         batch_norm_grad->operand(1)->shape(),
                                         batch_norm_grad->operand(2)->shape(),
                                         batch_norm_grad->operand(3)->shape(),
                                         batch_norm_grad->operand(4)->shape(),
                                         batch_norm_grad->feature_index()));
}

namespace {

// Checks that the instruction does not have mixed precision floating point
// inputs.
Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    // Whitelist the following opcodes for the mixed-precision check, because
    // they involve data pass through or grouping via tuples, where the
    // precisions of buffers can be different.
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kConstant:
    case HloOpcode::kCrossReplicaSum:
    case HloOpcode::kCustomCall:
    case HloOpcode::kFusion:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kSelect:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      break;
    default: {
      PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID;
      for (auto operand : instruction->operands()) {
        TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
            operand->shape(),
            [&](const Shape& subshape, const ShapeIndex& index) {
              if (!ShapeUtil::ElementIsFloating(subshape)) {
                return Status::OK();
              }
              if (fp_type == PRIMITIVE_TYPE_INVALID) {
                fp_type = subshape.element_type();
              } else if (fp_type != subshape.element_type()) {
                return FailedPrecondition(
                    "Seen floating point types of different precisions in "
                    "%s, but mixed precision is disallowed.",
                    instruction->ToString().c_str());
              }
              return Status::OK();
            }));
      }
    }
  }
  return Status::OK();
}

}  // namespace

Status ShapeVerifier::HandleGather(HloInstruction* gather) {
  return CheckShape(
      gather,
      ShapeInference::InferGatherShape(
          gather->operand(0)->shape(), gather->operand(1)->shape(),
          gather->gather_dimension_numbers(), gather->gather_window_bounds()));
}
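
// Checks the instruction's output shape against the shape inferred from its
// operands, optionally ignoring BF16/F32 precision differences for opcodes
// that define their own output buffers.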
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
                                 const Shape& inferred_shape) {
  // If allow_mixed_precision_ is false, check if there are operands with
  // different precisions. We need this check because ShapeInference allows
  // mixed precision inputs.
  if (!allow_mixed_precision_) {
    TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction));
  }

  // Check if the output shape matches the expected shape.
  bool compatible;
  // We treat BF16 and F32 as compatible types if mixed precision is allowed,
  // but only when the instruction defines the BF16/F32 buffer.
  switch (instruction->opcode()) {
    case HloOpcode::kSelect:
      if (ShapeUtil::IsTuple(inferred_shape) || !allow_mixed_precision_) {
        // Select only defines the top-level buffer, which in this case is the
        // tuple, so we cannot allow mixed precision.
        compatible =
            ShapeUtil::Compatible(instruction->shape(), inferred_shape);
      } else {
        compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
            instruction->shape(), inferred_shape);
      }
      break;
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
      // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed
      // precision is disallowed.
    case HloOpcode::kConstant:
    case HloOpcode::kBitcast:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kConvert:
    case HloOpcode::kCustomCall:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kWhile:
      // The above opcodes should match the expected shapes exactly.
      compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
      break;
    default:
      if (allow_mixed_precision_) {
        compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
            instruction->shape(), inferred_shape);
      } else {
        compatible =
            ShapeUtil::Compatible(instruction->shape(), inferred_shape);
      }
  }
  if (!compatible) {
    return InvalidArgument(
        "Expected instruction to have shape compatible with %s, actual "
        "shape is %s:\n%s",
        ShapeUtil::HumanString(inferred_shape).c_str(),
        ShapeUtil::HumanString(instruction->shape()).c_str(),
        instruction->ToString().c_str());
  }
  return tensorflow::Status::OK();
}

Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
                                 const StatusOr<Shape>& inferred_shape_status) {
  if (!inferred_shape_status.ok()) {
    Status s = inferred_shape_status.status();
    tensorflow::errors::AppendToMessage(&s, ", for instruction ",
                                        instruction->ToString());
    return s;
  }
  return CheckShape(instruction, inferred_shape_status.ValueOrDie());
}

Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {
  return CheckShape(instruction,
                    ShapeInference::InferUnaryOpShape(instruction->opcode(),
                                                      instruction->operand(0)));
}

Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) {
  return CheckShape(
      instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(),
                                                      instruction->operand(0),
                                                      instruction->operand(1)));
}

Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) {
  return CheckShape(instruction,
                    ShapeInference::InferTernaryOpShape(
                        instruction->opcode(), instruction->operand(0),
                        instruction->operand(1), instruction->operand(2)));
}

Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
  return CheckShape(instruction,
                    ShapeInference::InferVariadicOpShape(
                        instruction->opcode(), instruction->operands()));
}

// Checks that the given two instructions share the same channel id.
Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1,
                                       const HloInstruction* instr2) {
  if (instr1->channel_id() != instr2->channel_id()) {
    return FailedPrecondition(
        "Expected to have the same channel id, actual channel ids are: %s "
        "(%lld), %s (%lld)",
        instr1->ToString().c_str(), instr1->channel_id(),
        instr2->ToString().c_str(), instr2->channel_id());
  }
  return tensorflow::Status::OK();
}
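
// Renders a list of computations as a comma-separated string of their names,
// for use in error messages.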
string ComputationsToString(
    tensorflow::gtl::ArraySlice<HloComputation*> computations) {
  return tensorflow::str_util::Join(
      computations, ",", [](string* s, const HloComputation* computation) {
        s->append(computation->name());
      });
}

// Verifies various invariants about the structure of the HLO:
//
// (1) each instruction has a non-null parent() set to the HloComputation
//     which contains it.
//
// (2) each computation has a non-null parent() set to the HloModule which
//     contains it.
//
// (3) the operands of each instruction are in the same computation as the
//     instruction.
Status VerifyHloStructure(HloModule* module) {
  for (const HloComputation* computation : module->computations()) {
    if (computation->parent() == nullptr) {
      return FailedPrecondition("Computation %s has a null parent pointer",
                                computation->name().c_str());
    }
    if (computation->parent() != module) {
      return FailedPrecondition(
          "Computation %s parent() does not point to parent module",
          computation->name().c_str());
    }

    for (const HloInstruction* instruction : computation->instructions()) {
      if (instruction->parent() == nullptr) {
        return FailedPrecondition("Instruction %s has a null parent pointer",
                                  instruction->name().c_str());
      }
      if (instruction->parent() != computation) {
        return FailedPrecondition(
            "Instruction %s parent() does not point to parent computation",
            instruction->name().c_str());
      }
    }
  }

  // Check that operands are in the same computation separately from verifying
  // parent() correctness so conditions like a null HloInstruction::parent()
  // are identified and reported explicitly above rather than reporting a
  // mismatched operand.
  for (const HloComputation* computation : module->computations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      for (int i = 0; i < instruction->operand_count(); ++i) {
        const HloInstruction* operand = instruction->operand(i);
        if (operand->parent() != instruction->parent()) {
          return FailedPrecondition(
              "Operand %d (%s) of instruction %s is in a different "
              "computation: %s vs %s",
              i, operand->name().c_str(), instruction->name().c_str(),
              operand->parent()->name().c_str(),
              instruction->parent()->name().c_str());
        }
      }
    }
  }
  return tensorflow::Status::OK();
}
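
// Checks that a fusion instruction is internally consistent: the fused
// computation points back at the fusion, the fused root and parameters are
// owned by that computation, and the fused parameters line up with the
// fusion's operands.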
Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
  // The parent fusion instruction of the fusion computation must be 'fusion'.
  HloComputation* fused_computation = fusion->fused_instructions_computation();
  if (fusion != fused_computation->FusionInstruction()) {
    return FailedPrecondition(
        "Instruction of fused computation does not match expected "
        "instruction %s.",
        fusion->ToString().c_str());
  }

  // Fused root instruction and fused parameters must all be owned by the
  // fusion computation.
  bool root_owned = false;
  const std::vector<HloInstruction*>& fused_parameters =
      fusion->fused_parameters();
  const HloInstruction* fused_root = fusion->fused_expression_root();
  std::vector<bool> parameter_owned(fused_parameters.size(), false);
  for (auto* instruction : fused_computation->instructions()) {
    if (fused_root == instruction) {
      if (root_owned) {
        return FailedPrecondition("Root appears more than once in %s.",
                                  fusion->ToString().c_str());
      }
      root_owned = true;
    }
    for (int i = 0; i < fused_parameters.size(); ++i) {
      if (fused_parameters[i] == instruction) {
        if (parameter_owned[i]) {
          return FailedPrecondition("Parameter appears more than once in %s.",
                                    fusion->ToString().c_str());
        }
        parameter_owned[i] = true;
      }
    }
  }
  if (!root_owned) {
    return FailedPrecondition("Root not found in computation of %s.",
                              fusion->ToString().c_str());
  }
  // Make sure all the parameter_owned entries are set.
  for (int i = 0; i < parameter_owned.size(); i++) {
    if (!parameter_owned[i]) {
      return FailedPrecondition("Parameter %d not found in computation of %s.",
                                i, fusion->ToString().c_str());
    }
  }

  // Fused root must have no users.
  if (fused_root->user_count() != 0) {
    return FailedPrecondition("Root of %s may not have users.",
                              fusion->ToString().c_str());
  }

  // All uses of fused instructions must be in the fusion computation, and
  // every non-root instruction must have at least one use.
  for (auto* instruction :
       fusion->fused_instructions_computation()->instructions()) {
    if (instruction != fused_root) {
      if (instruction->user_count() == 0) {
        return FailedPrecondition(
            "Non-root instruction %s in %s must have users.",
            instruction->ToString().c_str(), fusion->ToString().c_str());
      }
      for (auto& user : instruction->users()) {
        if (fused_computation != user->parent()) {
          return FailedPrecondition(
              "Non-root instruction %s in %s may not have external users.",
              instruction->ToString().c_str(), fusion->ToString().c_str());
        }
      }
    }
  }

  // Fused parameter instructions must be numbered contiguously and match up
  // (shapes compatible) with their respective operand.
  CHECK_EQ(fusion->operands().size(), fused_parameters.size());
  std::vector<bool> parameter_numbers(fused_parameters.size(), false);
  for (auto fused_param : fused_parameters) {
    int64 param_no = fused_param->parameter_number();
    if (param_no < 0) {
      return FailedPrecondition(
          "Unexpected negative parameter number %lld in %s.", param_no,
          fusion->ToString().c_str());
    }
    if (param_no >= fused_parameters.size()) {
      return FailedPrecondition(
          "Unexpected parameter number %lld in %s: higher than the number of "
          "parameters %lu.",
          param_no, fusion->ToString().c_str(), fused_parameters.size());
    }
    if (parameter_numbers[param_no]) {
      return FailedPrecondition(
          "Did not expect parameter number %lld more than once in %s.",
          param_no, fusion->ToString().c_str());
    }
    parameter_numbers[param_no] = true;
    if (!ShapeUtil::Compatible(fused_param->shape(),
                               fusion->operand(param_no)->shape())) {
      return FailedPrecondition(
          "Shape mismatch between parameter number %lld and its operand in "
          "%s.",
          param_no, fusion->ToString().c_str());
    }
  }
  // Make sure all the parameter_numbers entries were seen.
  for (int i = 0; i < parameter_numbers.size(); i++) {
    if (!parameter_numbers[i]) {
      return FailedPrecondition("Did not see parameter number %d in %s.", i,
                                fusion->ToString().c_str());
    }
  }

  // TODO(b/65423525): We'd like to check that all operands are distinct.
  // This is currently disabled due to the invariant being violated by
  // multi-output fusion.
  return tensorflow::Status::OK();
}
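
// Runs all verification passes over the module: structural invariants,
// per-opcode invariants (fusion, broadcast, while), instruction-name
// uniqueness, and per-instruction shape checks. Returns false since
// verification never changes the module.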
StatusOr<bool> HloVerifier::Run(HloModule* module) {
  TF_RETURN_IF_ERROR(VerifyHloStructure(module));

  tensorflow::gtl::FlatMap<string, const HloInstruction*> instructions;

  for (auto* computation : module->computations()) {
    for (const auto& instruction : computation->instructions()) {
      TF_RET_CHECK(instruction->parent() == computation);
      if (instruction->opcode() == HloOpcode::kFusion) {
        TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction));
        TF_RET_CHECK(
            ContainersEqual(instruction->called_computations(),
                            {instruction->fused_instructions_computation()}))
            << "Fusion HLO calls computations other than the "
               "fused_instructions_computation: "
            << instruction->ToString()
            << " instruction->fused_instructions_computation(): "
            << instruction->fused_instructions_computation()->ToString()
            << " instruction->called_computations(): "
            << ComputationsToString(instruction->called_computations());

        for (const auto& fused : instruction->fused_instructions()) {
          TF_RET_CHECK(fused->parent() ==
                       instruction->fused_instructions_computation())
              << "Fused HLO was missing a parent: " << fused->ToString()
              << " parent: " << fused->parent()
              << " computation: " << computation;
        }
      } else if (instruction->opcode() == HloOpcode::kBroadcast) {
        // If you see this failure then someone has confused the HLO broadcast
        // op with the UserComputation broadcast op. See
        // https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
        // or ComputationLowerer::Visit().
        TF_RET_CHECK(instruction->dimensions().size() ==
                     ShapeUtil::Rank(instruction->operand(0)->shape()))
            << "Broadcast HLO has invalid number of dimensions.";
      } else if (instruction->opcode() == HloOpcode::kWhile) {
        auto* while_cond = instruction->while_condition();
        auto* while_body = instruction->while_body();
        TF_RET_CHECK(while_cond->num_parameters() == 1)
            << "While condition must have exactly 1 parameter; had "
            << while_cond->num_parameters() << ": " << while_cond->ToString();
        TF_RET_CHECK(while_body->num_parameters() == 1)
            << "While body must have exactly 1 parameter; had "
            << while_body->num_parameters() << ": " << while_body->ToString();
        TF_RET_CHECK(instruction->operand_count() == 1)
            << "While loop must have exactly one operand; had "
            << instruction->operand_count() << ": " << instruction->ToString();

        auto* init = instruction->operand(0);
        auto* cond_param = while_cond->parameter_instruction(0);
        TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), cond_param->shape()))
            << "While condition's parameter must have the same shape as the "
               "loop's 'init'. init: "
            << init->ToString() << ", param: " << cond_param->ToString();
        auto* cond_root = while_cond->root_instruction();
        TF_RET_CHECK(ShapeUtil::Compatible(cond_root->shape(),
                                           ShapeUtil::MakeShape(PRED, {})))
            << "While condition should have shape PRED: "
            << cond_root->ToString();

        auto* body_param = while_body->parameter_instruction(0);
        TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), body_param->shape()))
            << "While body's parameter must have the same shape as the "
               "loop's 'init'. init: "
            << init->ToString() << ", param: " << body_param->ToString();
        auto* body_root = while_body->root_instruction();
        TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), body_root->shape()))
            << "While body should have same shape as the loop's 'init'. "
               "init: "
            << init->ToString() << ", body: " << body_root->ToString();
      }
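
      // Instruction names must be unique across the entire module, not just
      // within a single computation.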
      auto previous = instructions.find(instruction->name());
      TF_RET_CHECK(previous == instructions.end())
          << "HLO has name that is not unique within module:\n"
          << instruction->ToString()
          << " in computation: " << computation->name()
          << "\nPrevious HLO with same name:\n"
          << previous->second->ToString()
          << " in computation: " << previous->second->parent()->name();
      instructions[instruction->name()] = instruction;
    }

    std::unique_ptr<ShapeVerifier> shape_verifier = shape_verifier_factory_();
    TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
  }

  return false;
}

}  // namespace xla