/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <memory>
#include <numeric>
#include <set>
#include <vector>

#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatmap.h"

namespace xla {

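// Most Handle* methods below re-run shape inference on the instruction's
// operands and check that the shape recorded on the instruction is compatible
// with the inferred one.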
Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
  return CheckUnaryShape(hlo);
}

Status ShapeVerifier::HandleElementwiseBinary(HloInstruction* hlo) {
  return CheckBinaryShape(hlo);
}

Status ShapeVerifier::HandleClamp(HloInstruction* clamp) {
  return CheckTernaryShape(clamp);
}

Status ShapeVerifier::HandleSelect(HloInstruction* select) {
  return CheckTernaryShape(select);
}

Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) {
  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : concatenate->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(concatenate,
                    ShapeInference::InferConcatOpShape(
                        operand_shapes, concatenate->concatenate_dimension()));
}

Status ShapeVerifier::HandleConvert(HloInstruction* convert) {
  return CheckShape(convert, ShapeInference::InferConvertShape(
                                 convert->operand(0)->shape(),
                                 convert->shape().element_type()));
}

Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) {
  return CheckShape(convert, ShapeInference::InferBitcastConvertShape(
                                 convert->operand(0)->shape(),
                                 convert->shape().element_type()));
}

Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
  return CheckUnaryShape(copy);
}

Status ShapeVerifier::HandleDot(HloInstruction* dot) {
  TF_ASSIGN_OR_RETURN(const Shape expected,
                      ShapeInference::InferDotOpShape(
                          dot->operand(0)->shape(), dot->operand(1)->shape(),
                          dot->dot_dimension_numbers()));
  return CheckShape(dot, expected);
}

Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
  TF_ASSIGN_OR_RETURN(
      const Shape expected,
      ShapeInference::InferConvolveShape(
          convolution->operand(0)->shape(), convolution->operand(1)->shape(),
          convolution->window(), convolution->convolution_dimension_numbers()));
  return CheckShape(convolution, expected);
}

Status ShapeVerifier::HandleFft(HloInstruction* fft) {
  TF_ASSIGN_OR_RETURN(
      const Shape expected,
      ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(),
                                    fft->fft_length()));
  return CheckShape(fft, expected);
}

Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) {
  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : crs->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(crs,
                    ShapeInference::InferCrossReplicaSumShape(operand_shapes));
}

Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
  return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
                                          reduce_precision->operand(0)->shape(),
                                          reduce_precision->exponent_bits(),
                                          reduce_precision->mantissa_bits()));
}

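// Infeed has no operands, so there is nothing to re-infer; its shape is
// accepted as recorded on the instruction.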
Status ShapeVerifier::HandleInfeed(HloInstruction*) {
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) {
  // Outfeed has a separate shape field for the value which is outfed to the
  // host. The shape of the instruction itself is always nil because the outfeed
  // produces no HLO value in the graph.
  if (!ShapeUtil::Compatible(outfeed->outfeed_shape(),
                             outfeed->operand(0)->shape())) {
    return InvalidArgument(
        "Expected outfeed to have shape compatible with operand's shape %s, "
        "actual shape is %s:\n%s",
        ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(),
        ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(),
        outfeed->ToString().c_str());
  }
  return CheckShape(outfeed, ShapeUtil::MakeNil());
}

Status ShapeVerifier::HandleHostCompute(HloInstruction*) {
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleRng(HloInstruction*) {
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
  return CheckShape(
      reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(),
                                                 reverse->dimensions()));
}

Status ShapeVerifier::HandleSort(HloInstruction* sort) {
  return CheckUnaryShape(sort);
}

Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
  return CheckShape(constant, constant->literal().shape());
}

Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  return CheckShape(get_tuple_element,
                    ShapeInference::InferGetTupleElementShape(
                        get_tuple_element->operand(0)->shape(),
                        get_tuple_element->tuple_index()));
}

Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
  return CheckShape(
      reduce,
      ShapeInference::InferReduceShape(
          reduce->operand(0)->shape(), reduce->operand(1)->shape(),
          reduce->dimensions(), reduce->to_apply()->ComputeProgramShape()));
}

Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
  // HLO broadcast has no exact analog at the proto level so there is no
  // ShapeInference method. Check the output shape explicitly.
  const Shape& operand_shape = broadcast->operand(0)->shape();
  // Check for mixed precision.
  TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape()));
  TF_RET_CHECK(ShapeUtil::Rank(operand_shape) ==
               broadcast->dimensions().size());
  for (int64 operand_dimension = 0;
       operand_dimension < ShapeUtil::Rank(operand_shape);
       ++operand_dimension) {
    int64 output_dimension = broadcast->dimensions()[operand_dimension];
    TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) ==
                 operand_shape.dimensions(operand_dimension))
        << broadcast->ToString() << " operand shape " << operand_shape;
  }
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
  // Check for mixed precision.
  TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape()));
  TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
               ShapeUtil::ElementsIn(reshape->operand(0)->shape()));
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
  return CheckShape(
      transpose, ShapeInference::InferTransposeShape(
                     transpose->operand(0)->shape(), transpose->dimensions()));
}

Status ShapeVerifier::HandleParameter(HloInstruction*) {
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleFusion(HloInstruction*) {
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleCall(HloInstruction* call) {
  // The shape of kCall should match the shape of the computation it calls.
  return CheckShape(call, call->to_apply()->ComputeProgramShape().result());
}

Status ShapeVerifier::HandleCustomCall(HloInstruction*) {
  return tensorflow::Status::OK();
}

Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
  return CheckShape(slice,
                    ShapeInference::InferSliceShape(
                        slice->operand(0)->shape(), slice->slice_starts(),
                        slice->slice_limits(), slice->slice_strides()));
}

Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) {
  return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape(
                                       dynamic_slice->operand(0)->shape(),
                                       dynamic_slice->operand(1)->shape(),
                                       dynamic_slice->dynamic_slice_sizes()));
}

Status ShapeVerifier::HandleDynamicUpdateSlice(
    HloInstruction* dynamic_update_slice) {
  return CheckShape(dynamic_update_slice,
                    ShapeInference::InferDynamicUpdateSliceShape(
                        dynamic_update_slice->operand(0)->shape(),
                        dynamic_update_slice->operand(1)->shape(),
                        dynamic_update_slice->operand(2)->shape()));
}

Status ShapeVerifier::HandleTuple(HloInstruction* tuple) {
  return CheckVariadicShape(tuple);
}

Status ShapeVerifier::HandleMap(HloInstruction* map) {
  std::vector<const Shape*> operand_shapes;
  int64 max_operand_rank = 0;
  for (const HloInstruction* operand : map->operands()) {
    operand_shapes.push_back(&operand->shape());
    max_operand_rank =
        std::max(max_operand_rank, ShapeUtil::Rank(operand->shape()));
  }
  // TODO(b/65689298) Remove code below once Map is generalized to accept
  // arbitrary map dimensions.
  std::vector<int64> map_dims(max_operand_rank);
  std::iota(map_dims.begin(), map_dims.end(), 0);
  return CheckShape(map, ShapeInference::InferMapShape(
                             operand_shapes,
                             map->to_apply()->ComputeProgramShape(), map_dims));
}

Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
  return CheckShape(
      reduce_window,
      ShapeInference::InferReduceWindowShape(
          reduce_window->operand(0)->shape(),
          reduce_window->operand(1)->shape(), reduce_window->window(),
          reduce_window->to_apply()->ComputeProgramShape()));
}

Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
  return CheckShape(
      instruction,
      ShapeInference::InferSelectAndScatterShape(
          instruction->operand(0)->shape(),
          instruction->select()->ComputeProgramShape(), instruction->window(),
          instruction->operand(1)->shape(), instruction->operand(2)->shape(),
          instruction->scatter()->ComputeProgramShape()));
}

Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
  // The shape of kWhile should match the shape of the body computation it
  // calls.
  return CheckShape(xla_while,
                    xla_while->while_body()->ComputeProgramShape().result());
}

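// A conditional must have the same shape as the result of both its true and
// false computations, since either branch may produce its value.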
Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
  TF_RETURN_IF_ERROR(CheckShape(
      conditional,
      conditional->true_computation()->ComputeProgramShape().result()));
  return CheckShape(
      conditional,
      conditional->false_computation()->ComputeProgramShape().result());
}

Status ShapeVerifier::HandlePad(HloInstruction* pad) {
  return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(),
                                                       pad->operand(1)->shape(),
                                                       pad->padding_config()));
}

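// Send/SendDone and Recv/RecvDone must come in matched pairs on the same
// channel: each *Done instruction is the sole user (or the sole operand) of
// its partner, and the pair's shapes are checked against the transferred
// value.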
Status ShapeVerifier::HandleSend(HloInstruction* send) {
  TF_RET_CHECK(send->users().size() == 1);
  const HloInstruction* send_done = send->users().front();
  TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
  TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
  return CheckShape(
      send, ShapeUtil::MakeTupleShape(
                {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})}));
}

Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
  TF_RET_CHECK(send_done->operands().size() == 1);
  const HloInstruction* send = send_done->operand(0);
  TF_RET_CHECK(send->opcode() == HloOpcode::kSend);
  TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
  return CheckShape(send_done, ShapeUtil::MakeNil());
}

Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
  TF_RET_CHECK(recv->users().size() == 1);
  const HloInstruction* recv_done = recv->users().front();
  TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
  TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
  return CheckShape(recv,
                    ShapeUtil::MakeTupleShape(
                        {recv_done->shape(), ShapeUtil::MakeShape(U32, {})}));
}

Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
  TF_RET_CHECK(recv_done->operands().size() == 1);
  const HloInstruction* recv = recv_done->operand(0);
  TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv);
  TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
  return CheckShape(recv_done, recv->shape().tuple_shapes(0));
}

Status ShapeVerifier::HandleBatchNormTraining(
    HloInstruction* batch_norm_training) {
  return CheckShape(batch_norm_training,
                    ShapeInference::InferBatchNormTrainingShape(
                        batch_norm_training->operand(0)->shape(),
                        batch_norm_training->operand(1)->shape(),
                        batch_norm_training->operand(2)->shape(),
                        batch_norm_training->feature_index()));
}

Status ShapeVerifier::HandleBatchNormInference(
    HloInstruction* batch_norm_inference) {
  return CheckShape(batch_norm_inference,
                    ShapeInference::InferBatchNormInferenceShape(
                        batch_norm_inference->operand(0)->shape(),
                        batch_norm_inference->operand(1)->shape(),
                        batch_norm_inference->operand(2)->shape(),
                        batch_norm_inference->operand(3)->shape(),
                        batch_norm_inference->operand(4)->shape(),
                        batch_norm_inference->feature_index()));
}

Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
  return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape(
                                         batch_norm_grad->operand(0)->shape(),
                                         batch_norm_grad->operand(1)->shape(),
                                         batch_norm_grad->operand(2)->shape(),
                                         batch_norm_grad->operand(3)->shape(),
                                         batch_norm_grad->operand(4)->shape(),
                                         batch_norm_grad->feature_index()));
}

namespace {

// Checks that the instruction does not have mixed precision floating point
// inputs.
Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    // The following opcodes are whitelisted, i.e. they are allowed to have
    // operands of different floating-point precisions, because they only pass
    // data through or group it via tuples, where the individual buffers may
    // have different precisions.
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kConstant:
    case HloOpcode::kCrossReplicaSum:
    case HloOpcode::kCustomCall:
    case HloOpcode::kFusion:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kSelect:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      break;
    default: {
      PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID;
      for (auto operand : instruction->operands()) {
        TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
            operand->shape(),
            [&](const Shape& subshape, const ShapeIndex& index) {
              if (!ShapeUtil::ElementIsFloating(subshape)) {
                return Status::OK();
              }
              if (fp_type == PRIMITIVE_TYPE_INVALID) {
                fp_type = subshape.element_type();
              } else if (fp_type != subshape.element_type()) {
                return FailedPrecondition(
                    "Seen floating point types of different precisions in "
                    "%s, but mixed precision is disallowed.",
                    instruction->ToString().c_str());
              }
              return Status::OK();
            }));
      }
    }
  }
  return Status::OK();
}

}  // namespace

Status ShapeVerifier::HandleGather(HloInstruction* gather) {
  return CheckShape(
      gather,
      ShapeInference::InferGatherShape(
          gather->operand(0)->shape(), gather->operand(1)->shape(),
          gather->gather_dimension_numbers(), gather->gather_window_bounds()));
}

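// Compares the instruction's declared shape with the shape inferred from its
// operands. BF16 and F32 are treated as interchangeable only when mixed
// precision is allowed and the instruction defines its own output buffer.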
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
                                 const Shape& inferred_shape) {
  // If allow_mixed_precision_ is false, check if there are operands with
  // different precisions. We need this check because ShapeInference allows
  // mixed precision inputs.
  if (!allow_mixed_precision_) {
    TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction));
  }

  // Check if the output shape matches the expected shape.
  bool compatible;
  // We treat BF16 and F32 as compatible types if mixed precision is allowed,
  // but only when the instruction defines the BF16/F32 buffer.
  switch (instruction->opcode()) {
    case HloOpcode::kSelect:
      if (ShapeUtil::IsTuple(inferred_shape) || !allow_mixed_precision_) {
        // Select only defines the top-level buffer, which in this case is the
        // tuple, so we cannot allow mixed precision.
        compatible =
            ShapeUtil::Compatible(instruction->shape(), inferred_shape);
      } else {
        compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
            instruction->shape(), inferred_shape);
      }
      break;
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
      // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed
      // precision is disallowed.
    case HloOpcode::kConstant:
    case HloOpcode::kBitcast:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kConvert:
    case HloOpcode::kCustomCall:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kWhile:
      // The above opcodes should match the expected shapes exactly.
      compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
      break;
    default:
      if (allow_mixed_precision_) {
        compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
            instruction->shape(), inferred_shape);
      } else {
        compatible =
            ShapeUtil::Compatible(instruction->shape(), inferred_shape);
      }
  }
  if (!compatible) {
    return InvalidArgument(
        "Expected instruction to have shape compatible with %s, actual "
        "shape is %s:\n%s",
        ShapeUtil::HumanString(inferred_shape).c_str(),
        ShapeUtil::HumanString(instruction->shape()).c_str(),
        instruction->ToString().c_str());
  }
  return tensorflow::Status::OK();
}

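// Overload that unwraps a shape-inference result, annotating any inference
// error with the offending instruction before returning it.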
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
                                 const StatusOr<Shape>& inferred_shape_status) {
  if (!inferred_shape_status.ok()) {
    Status s = inferred_shape_status.status();
    tensorflow::errors::AppendToMessage(&s, ", for instruction ",
                                        instruction->ToString());
    return s;
  }
  return CheckShape(instruction, inferred_shape_status.ValueOrDie());
}

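// Convenience wrappers that re-run shape inference for operations of a fixed
// arity and defer the comparison to CheckShape.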
Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {
  return CheckShape(instruction,
                    ShapeInference::InferUnaryOpShape(instruction->opcode(),
                                                      instruction->operand(0)));
}

Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) {
  return CheckShape(
      instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(),
                                                      instruction->operand(0),
                                                      instruction->operand(1)));
}

Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) {
  return CheckShape(instruction,
                    ShapeInference::InferTernaryOpShape(
                        instruction->opcode(), instruction->operand(0),
                        instruction->operand(1), instruction->operand(2)));
}

Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
  return CheckShape(instruction,
                    ShapeInference::InferVariadicOpShape(
                        instruction->opcode(), instruction->operands()));
}

// Checks that the given two instructions share the same channel id.
Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1,
                                       const HloInstruction* instr2) {
  if (instr1->channel_id() != instr2->channel_id()) {
    return FailedPrecondition(
        "Expected to have the same channel id, actual channel ids are: %s "
        "(%lld), %s (%lld)",
        instr1->ToString().c_str(), instr1->channel_id(),
        instr2->ToString().c_str(), instr2->channel_id());
  }
  return tensorflow::Status::OK();
}

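// Renders a list of computations as a comma-separated string of their names,
// for use in error messages.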
string ComputationsToString(
    tensorflow::gtl::ArraySlice<HloComputation*> computations) {
  return tensorflow::str_util::Join(
      computations, ",", [](string* s, const HloComputation* computation) {
        s->append(computation->name());
      });
}

// Verifies various invariants about the structure of the HLO:
//
// (1) each instruction has a non-null parent() set to the HloComputation which
//     contains it.
//
// (2) each computation has a non-null parent() set to the HloModule which
//     contains it.
//
// (3) the operands of each instruction are in the same computation as the
//     instruction.
Status VerifyHloStructure(HloModule* module) {
  for (const HloComputation* computation : module->computations()) {
    if (computation->parent() == nullptr) {
      return FailedPrecondition("Computation %s has a null parent pointer",
                                computation->name().c_str());
    }
    if (computation->parent() != module) {
      return FailedPrecondition(
          "Computation %s parent() does not point to parent module",
          computation->name().c_str());
    }

    for (const HloInstruction* instruction : computation->instructions()) {
      if (instruction->parent() == nullptr) {
        return FailedPrecondition("Instruction %s has a null parent pointer",
                                  instruction->name().c_str());
      }
      if (instruction->parent() != computation) {
        return FailedPrecondition(
            "Instruction %s parent() does not point to parent computation",
            instruction->name().c_str());
      }
    }
  }

  // Check that operands are in the same computation separately from verifying
  // parent() correctness so conditions like a null HloInstruction::parent() are
  // identified and reported explicitly above rather than reporting a mismatched
  // operand.
  for (const HloComputation* computation : module->computations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      for (int i = 0; i < instruction->operand_count(); ++i) {
        const HloInstruction* operand = instruction->operand(i);
        if (operand->parent() != instruction->parent()) {
          return FailedPrecondition(
              "Operand %d (%s) of instruction %s is in a different "
              "computation: %s vs %s",
              i, operand->name().c_str(), instruction->name().c_str(),
              operand->parent()->name().c_str(),
              instruction->parent()->name().c_str());
        }
      }
    }
  }
  return tensorflow::Status::OK();
}

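// Checks structural invariants of a fusion instruction: its fused computation
// points back to it, the fused root and parameters are owned by that
// computation, non-root fused instructions have only internal users, and
// parameter numbers are contiguous with shapes matching the corresponding
// fusion operands.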
Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
  // The parent fusion instruction of the fusion computation must be 'fusion'.
  HloComputation* fused_computation = fusion->fused_instructions_computation();
  if (fusion != fused_computation->FusionInstruction()) {
    return FailedPrecondition(
        "Instruction of fused computation does not match expected instruction "
        "%s.",
        fusion->ToString().c_str());
  }

  // Fused root instruction and fused parameters must all be owned by the fusion
  // computation.
  bool root_owned = false;
  const std::vector<HloInstruction*>& fused_parameters =
      fusion->fused_parameters();
  const HloInstruction* fused_root = fusion->fused_expression_root();
  std::vector<bool> parameter_owned(fused_parameters.size(), false);
  for (auto* instruction : fused_computation->instructions()) {
    if (fused_root == instruction) {
      if (root_owned) {
        return FailedPrecondition("Root appears more than once in %s.",
                                  fusion->ToString().c_str());
      }
      root_owned = true;
    }
    for (int i = 0; i < fused_parameters.size(); ++i) {
      if (fused_parameters[i] == instruction) {
        if (parameter_owned[i]) {
          return FailedPrecondition("Parameter appears more than once in %s.",
                                    fusion->ToString().c_str());
        }
        parameter_owned[i] = true;
      }
    }
  }
  if (!root_owned) {
    return FailedPrecondition("Root not found in computation of %s.",
                              fusion->ToString().c_str());
  }
  // Make sure all the parameter_owned entries are set
  for (int i = 0; i < parameter_owned.size(); i++) {
    if (!parameter_owned[i]) {
      return FailedPrecondition("Parameter %d not found in computation of %s.",
                                i, fusion->ToString().c_str());
    }
  }

  // Fused root must have no users.
  if (fused_root->user_count() != 0) {
    return FailedPrecondition("Root of %s may not have users.",
                              fusion->ToString().c_str());
  }

  // All uses of fused instructions must be in the fusion computation, and every
  // non-root instruction must have at least one use.
  for (auto* instruction :
       fusion->fused_instructions_computation()->instructions()) {
    if (instruction != fused_root) {
      if (instruction->user_count() == 0) {
        return FailedPrecondition(
            "Non-root instruction %s in %s must have users.",
            instruction->ToString().c_str(), fusion->ToString().c_str());
      }
      for (auto& user : instruction->users()) {
        if (fused_computation != user->parent()) {
          return FailedPrecondition(
              "Non-root instruction %s in %s may not have external users.",
              instruction->ToString().c_str(), fusion->ToString().c_str());
        }
      }
    }
  }

  // Fused parameter instructions must be numbered contiguously and match up
  // (shapes compatible) with their respective operand.
  CHECK_EQ(fusion->operands().size(), fused_parameters.size());
  std::vector<bool> parameter_numbers(fused_parameters.size(), false);
  for (auto fused_param : fused_parameters) {
    int64 param_no = fused_param->parameter_number();
    if (param_no < 0) {
      return FailedPrecondition(
          "Unexpected negative parameter number %lld in %s.", param_no,
          fusion->ToString().c_str());
    }
    if (param_no >= fused_parameters.size()) {
      return FailedPrecondition(
          "Unexpected parameter number %lld in %s: higher than the number of "
          "parameters %lu.",
          param_no, fusion->ToString().c_str(), fused_parameters.size());
    }
    if (parameter_numbers[param_no]) {
      return FailedPrecondition(
          "Did not expect parameter number %lld more than once in %s.",
          param_no, fusion->ToString().c_str());
    }
    parameter_numbers[param_no] = true;
    if (!ShapeUtil::Compatible(fused_param->shape(),
                               fusion->operand(param_no)->shape())) {
      return FailedPrecondition(
          "Shape mismatch between parameter number %lld and its operand in %s.",
          param_no, fusion->ToString().c_str());
    }
  }
  // Make sure all the parameter_numbers entries were seen
  for (int i = 0; i < parameter_numbers.size(); i++) {
    if (!parameter_numbers[i]) {
      return FailedPrecondition("Did not see parameter number %d in %s.", i,
                                fusion->ToString().c_str());
    }
  }

  // TODO(b/65423525): We'd like to check that all operands are distinct.
  // This is currently disabled due to the invariant being violated by
  // multi-output fusion.
  return tensorflow::Status::OK();
}

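// Runs structural verification, per-instruction invariant checks (fusion
// structure, broadcast dimension count, while signatures, unique instruction
// names), and a ShapeVerifier pass over every computation. Always returns
// false, since verification never changes the module.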
StatusOr<bool> HloVerifier::Run(HloModule* module) {
  TF_RETURN_IF_ERROR(VerifyHloStructure(module));

  tensorflow::gtl::FlatMap<string, const HloInstruction*> instructions;

  for (auto* computation : module->computations()) {
    for (const auto& instruction : computation->instructions()) {
      TF_RET_CHECK(instruction->parent() == computation);
      if (instruction->opcode() == HloOpcode::kFusion) {
        TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction));
        TF_RET_CHECK(
            ContainersEqual(instruction->called_computations(),
                            {instruction->fused_instructions_computation()}))
            << "Fusion HLO calls computations other than the "
               "fused_instructions_computation: "
            << instruction->ToString()
            << " instruction->fused_instructions_computation(): "
            << instruction->fused_instructions_computation()->ToString()
            << " instruction->called_computations(): "
            << ComputationsToString(instruction->called_computations());

        for (const auto& fused : instruction->fused_instructions()) {
          TF_RET_CHECK(fused->parent() ==
                       instruction->fused_instructions_computation())
              << "Fused HLO was missing a parent: " << fused->ToString()
              << " parent: " << fused->parent()
              << " computation: " << computation;
        }
      } else if (instruction->opcode() == HloOpcode::kBroadcast) {
        // If you see this failure then someone has confused the difference
        // between the HLO broadcast op, and the UserComputation broadcast
        // op.  See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
        // or ComputationLowerer::Visit()
        TF_RET_CHECK(instruction->dimensions().size() ==
                     ShapeUtil::Rank(instruction->operand(0)->shape()))
            << "Broadcast HLO has invalid number of dimensions.";
      } else if (instruction->opcode() == HloOpcode::kWhile) {
        auto* while_cond = instruction->while_condition();
        auto* while_body = instruction->while_body();
        TF_RET_CHECK(while_cond->num_parameters() == 1)
            << "While condition must have exactly 1 parameter; had "
            << while_cond->num_parameters() << ": " << while_cond->ToString();
        TF_RET_CHECK(while_body->num_parameters() == 1)
            << "While body must have exactly 1 parameter; had "
            << while_body->num_parameters() << ": " << while_body->ToString();
        TF_RET_CHECK(instruction->operand_count() == 1)
            << "While loop must have exactly one operand; had "
            << instruction->operand_count() << ": " << instruction->ToString();

        auto* init = instruction->operand(0);
        auto* cond_param = while_cond->parameter_instruction(0);
        TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), cond_param->shape()))
            << "While condition's parameter must have the same shape as the "
               "loop's 'init'. init: "
            << init->ToString() << ", param: " << cond_param->ToString();
        auto* cond_root = while_cond->root_instruction();
        TF_RET_CHECK(ShapeUtil::Compatible(cond_root->shape(),
                                           ShapeUtil::MakeShape(PRED, {})))
            << "While condition should have shape PRED: "
            << cond_root->ToString();

        auto* body_param = while_body->parameter_instruction(0);
        TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), body_param->shape()))
            << "While body's parameter must have the same shape as the loop's "
               "'init'. init: "
            << init->ToString() << ", param: " << body_param->ToString();
        auto* body_root = while_body->root_instruction();
        TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), body_root->shape()))
            << "While body should have same shape as the loop's 'init'. init: "
            << init->ToString() << ", body: " << body_root->ToString();
      }

      auto previous = instructions.find(instruction->name());
      TF_RET_CHECK(previous == instructions.end())
          << "HLO has name that is not unique within module:\n"
          << instruction->ToString()
          << " in computation: " << computation->name()
          << "\nPrevious HLO with same name:\n"
          << previous->second->ToString()
          << " in computation: " << previous->second->parent()->name();
      instructions[instruction->name()] = instruction;
    }

    std::unique_ptr<ShapeVerifier> shape_verifier = shape_verifier_factory_();
    TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
  }

  return false;
}

}  // namespace xla