Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
     18 
#include <functional>
#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
     22 
     23 namespace xla {
     24 
     25 // Visitor which verifies that the output shape is correctly set. Verifies
     26 // against the inferred shape for the instruction.
     27 // TODO(b/26024837): Check output shape for all instruction types.
     28 class ShapeVerifier : public DfsHloVisitor {
     29  public:
     30   explicit ShapeVerifier() : allow_mixed_precision_(false) {}
     31   explicit ShapeVerifier(bool allow_mixed_precision)
     32       : allow_mixed_precision_(allow_mixed_precision) {}
     33 
     34   Status HandleElementwiseUnary(HloInstruction* hlo) override;
     35   Status HandleElementwiseBinary(HloInstruction* hlo) override;
     36   Status HandleClamp(HloInstruction* clamp) override;
     37   Status HandleSelect(HloInstruction* select) override;
     38   Status HandleConcatenate(HloInstruction* concatenate) override;
     39   Status HandleConvert(HloInstruction* convert) override;
     40   Status HandleBitcastConvert(HloInstruction* convert) override;
     41   Status HandleCopy(HloInstruction* copy) override;
     42   Status HandleDot(HloInstruction* dot) override;
     43   Status HandleConvolution(HloInstruction* convolution) override;
     44   Status HandleFft(HloInstruction* fft) override;
     45   Status HandleCrossReplicaSum(HloInstruction* crs) override;
     46   Status HandleReducePrecision(HloInstruction* reduce_precision) override;
     47   Status HandleInfeed(HloInstruction*) override;
     48   Status HandleOutfeed(HloInstruction*) override;
     49   Status HandleRng(HloInstruction*) override;
     50   Status HandleReverse(HloInstruction* reverse) override;
     51   Status HandleSort(HloInstruction* sort) override;
     52   Status HandleConstant(HloInstruction* constant) override;
     53   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
     54   Status HandleReduce(HloInstruction* reduce) override;
     55   Status HandleBitcast(HloInstruction* bitcast) override;
     56   Status HandleBroadcast(HloInstruction* broadcast) override;
     57   Status HandleReshape(HloInstruction* reshape) override;
     58   Status HandleTranspose(HloInstruction* transpose) override;
     59   Status HandleParameter(HloInstruction*) override;
     60   Status HandleFusion(HloInstruction*) override;
     61   Status HandleCall(HloInstruction* call) override;
     62   Status HandleCustomCall(HloInstruction*) override;
     63   Status HandleHostCompute(HloInstruction*) override;
     64   Status HandleSlice(HloInstruction* slice) override;
     65   Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
     66   Status HandleDynamicUpdateSlice(
     67       HloInstruction* dynamic_update_slice) override;
     68   Status HandleTuple(HloInstruction* tuple) override;
     69   Status HandleMap(HloInstruction* map) override;
     70   Status HandleReduceWindow(HloInstruction* reduce_window) override;
     71   Status HandleSelectAndScatter(HloInstruction* instruction) override;
     72   Status HandleWhile(HloInstruction* xla_while) override;
     73   Status HandleConditional(HloInstruction* conditional) override;
     74   Status HandlePad(HloInstruction* pad) override;
     75   Status HandleSend(HloInstruction* send) override;
     76   Status HandleSendDone(HloInstruction* send_done) override;
     77   Status HandleRecv(HloInstruction* recv) override;
     78   Status HandleRecvDone(HloInstruction* recv_done) override;
     79   Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override;
     80   Status HandleBatchNormInference(
     81       HloInstruction* batch_norm_inference) override;
     82   Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
     83   Status HandleGather(HloInstruction* gather) override;
     84 
     85   Status FinishVisit(HloInstruction*) override {
     86     return tensorflow::Status::OK();
     87   }
     88 
     89  protected:
     90   // Check the instruction's shape against the shape given by ShapeInference
     91   // and return an appropriate error if there is a mismatch.
     92   Status CheckShape(const HloInstruction* instruction,
     93                     const Shape& inferred_shape);
     94 
     95   // Overload which takes a StatusOr to reduce boilerplate in the caller.
     96   Status CheckShape(const HloInstruction* instruction,
     97                     const StatusOr<Shape>& inferred_shape_status);
     98 
     99   // Check a unary (binary, etc) instruction's shape against the inferred shape.
    100   Status CheckUnaryShape(const HloInstruction* instruction);
    101   Status CheckBinaryShape(const HloInstruction* instruction);
    102   Status CheckTernaryShape(const HloInstruction* instruction);
    103   Status CheckVariadicShape(const HloInstruction* instruction);
    104 
    105   // Checks if the given two instructions shares the same channel id.
    106   Status CheckSameChannel(const HloInstruction* instr1,
    107                           const HloInstruction* instr2);
    108 
    109  private:
    110   // Whether the inputs and output of an instruction can contain both F32s and
    111   // BF16s. Tuples that include both F32s and BF16s are allowed regardless of
    112   // this flag.
    113   bool allow_mixed_precision_;
    114 };
    115 
    116 // HLO pass that verifies invariants of HLO instructions for each computation in
    117 // the module.
    118 class HloVerifier : public HloPassInterface {
    119  public:
    120   using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
    121 
    122   // Uses standard shape inference.
    123   explicit HloVerifier()
    124       : shape_verifier_factory_(
    125             [] { return MakeUnique<ShapeVerifier>(false); }) {}
    126 
    127   explicit HloVerifier(bool allow_mixed_precision)
    128       : shape_verifier_factory_([allow_mixed_precision] {
    129           return MakeUnique<ShapeVerifier>(allow_mixed_precision);
    130         }) {}
    131 
    132   // Uses custom shape verification.
    133   explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory)
    134       : shape_verifier_factory_(std::move(shape_verifier_factory)) {}
    135 
    136   ~HloVerifier() override = default;
    137   tensorflow::StringPiece name() const override { return "verifier"; }
    138 
    139   // Note: always returns false (no instructions are ever modified by this
    140   // pass).
    141   StatusOr<bool> Run(HloModule* module) override;
    142 
    143  private:
    144   // CHECKs various invariants of a fusion instruction.
    145   Status CheckFusionInstruction(HloInstruction* fusion) const;
    146 
    147   // Creates a ShapeVerifier that checks that shapes match inferred
    148   // expectations.  This is a factory function because ShapeVerifier,  Note that
    149   // ShapeVerifier, being a DfsHloVisitor, is stateful.  We want a clean object
    150   // for each run of the verifier.
    151   ShapeVerifierFactory shape_verifier_factory_;
    152 };
    153 
    154 }  // namespace xla
    155 
    156 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
    157