1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ 18 19 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 20 21 #include "tensorflow/compiler/xla/service/shape_inference.h" 22 23 namespace xla { 24 25 // Visitor which verifies that the output shape is correctly set. Verifies 26 // against the inferred shape for the instruction. 27 // TODO(b/26024837): Check output shape for all instruction types. 28 class ShapeVerifier : public DfsHloVisitor { 29 public: 30 explicit ShapeVerifier() : allow_mixed_precision_(false) {} 31 explicit ShapeVerifier(bool allow_mixed_precision) 32 : allow_mixed_precision_(allow_mixed_precision) {} 33 34 Status HandleElementwiseUnary(HloInstruction* hlo) override; 35 Status HandleElementwiseBinary(HloInstruction* hlo) override; 36 Status HandleClamp(HloInstruction* clamp) override; 37 Status HandleSelect(HloInstruction* select) override; 38 Status HandleConcatenate(HloInstruction* concatenate) override; 39 Status HandleConvert(HloInstruction* convert) override; 40 Status HandleBitcastConvert(HloInstruction* convert) override; 41 Status HandleCopy(HloInstruction* copy) override; 42 Status HandleDot(HloInstruction* dot) override; 43 Status HandleConvolution(HloInstruction* convolution) override; 44 Status HandleFft(HloInstruction* fft) override; 45 Status HandleCrossReplicaSum(HloInstruction* crs) override; 46 Status HandleReducePrecision(HloInstruction* reduce_precision) override; 47 Status HandleInfeed(HloInstruction*) override; 48 Status HandleOutfeed(HloInstruction*) override; 49 Status HandleRng(HloInstruction*) override; 50 Status HandleReverse(HloInstruction* reverse) override; 51 Status HandleSort(HloInstruction* sort) override; 52 Status HandleConstant(HloInstruction* constant) override; 53 Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; 54 Status HandleReduce(HloInstruction* reduce) override; 55 Status HandleBitcast(HloInstruction* bitcast) override; 56 Status HandleBroadcast(HloInstruction* broadcast) override; 57 Status HandleReshape(HloInstruction* reshape) override; 58 Status HandleTranspose(HloInstruction* transpose) override; 59 Status HandleParameter(HloInstruction*) override; 60 Status HandleFusion(HloInstruction*) override; 61 Status HandleCall(HloInstruction* call) override; 62 Status HandleCustomCall(HloInstruction*) override; 63 Status HandleHostCompute(HloInstruction*) override; 64 Status HandleSlice(HloInstruction* slice) override; 65 Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; 66 Status HandleDynamicUpdateSlice( 67 HloInstruction* dynamic_update_slice) override; 68 Status HandleTuple(HloInstruction* tuple) override; 69 Status HandleMap(HloInstruction* map) override; 70 Status HandleReduceWindow(HloInstruction* reduce_window) override; 71 Status HandleSelectAndScatter(HloInstruction* instruction) override; 72 Status HandleWhile(HloInstruction* xla_while) override; 73 Status HandleConditional(HloInstruction* conditional) override; 74 Status HandlePad(HloInstruction* pad) override; 75 Status HandleSend(HloInstruction* send) override; 76 Status HandleSendDone(HloInstruction* send_done) override; 77 Status HandleRecv(HloInstruction* recv) override; 78 Status HandleRecvDone(HloInstruction* recv_done) override; 79 Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; 80 Status HandleBatchNormInference( 81 HloInstruction* batch_norm_inference) override; 82 Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; 83 Status HandleGather(HloInstruction* gather) override; 84 85 Status FinishVisit(HloInstruction*) override { 86 return tensorflow::Status::OK(); 87 } 88 89 protected: 90 // Check the instruction's shape against the shape given by ShapeInference 91 // and return an appropriate error if there is a mismatch. 92 Status CheckShape(const HloInstruction* instruction, 93 const Shape& inferred_shape); 94 95 // Overload which takes a StatusOr to reduce boilerplate in the caller. 96 Status CheckShape(const HloInstruction* instruction, 97 const StatusOr<Shape>& inferred_shape_status); 98 99 // Check a unary (binary, etc) instruction's shape against the inferred shape. 100 Status CheckUnaryShape(const HloInstruction* instruction); 101 Status CheckBinaryShape(const HloInstruction* instruction); 102 Status CheckTernaryShape(const HloInstruction* instruction); 103 Status CheckVariadicShape(const HloInstruction* instruction); 104 105 // Checks if the given two instructions shares the same channel id. 106 Status CheckSameChannel(const HloInstruction* instr1, 107 const HloInstruction* instr2); 108 109 private: 110 // Whether the inputs and output of an instruction can contain both F32s and 111 // BF16s. Tuples that include both F32s and BF16s are allowed regardless of 112 // this flag. 113 bool allow_mixed_precision_; 114 }; 115 116 // HLO pass that verifies invariants of HLO instructions for each computation in 117 // the module. 118 class HloVerifier : public HloPassInterface { 119 public: 120 using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>; 121 122 // Uses standard shape inference. 123 explicit HloVerifier() 124 : shape_verifier_factory_( 125 [] { return MakeUnique<ShapeVerifier>(false); }) {} 126 127 explicit HloVerifier(bool allow_mixed_precision) 128 : shape_verifier_factory_([allow_mixed_precision] { 129 return MakeUnique<ShapeVerifier>(allow_mixed_precision); 130 }) {} 131 132 // Uses custom shape verification. 133 explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory) 134 : shape_verifier_factory_(std::move(shape_verifier_factory)) {} 135 136 ~HloVerifier() override = default; 137 tensorflow::StringPiece name() const override { return "verifier"; } 138 139 // Note: always returns false (no instructions are ever modified by this 140 // pass). 141 StatusOr<bool> Run(HloModule* module) override; 142 143 private: 144 // CHECKs various invariants of a fusion instruction. 145 Status CheckFusionInstruction(HloInstruction* fusion) const; 146 147 // Creates a ShapeVerifier that checks that shapes match inferred 148 // expectations. This is a factory function because ShapeVerifier, Note that 149 // ShapeVerifier, being a DfsHloVisitor, is stateful. We want a clean object 150 // for each run of the verifier. 151 ShapeVerifierFactory shape_verifier_factory_; 152 }; 153 154 } // namespace xla 155 156 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ 157