Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
     18 
     19 #include <type_traits>
     20 #include <vector>
     21 
     22 #include "tensorflow/compiler/xla/literal_util.h"
     23 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     24 #include "tensorflow/compiler/xla/status.h"
     25 #include "tensorflow/compiler/xla/types.h"
     26 #include "tensorflow/compiler/xla/xla_data.pb.h"
     27 #include "tensorflow/core/lib/core/status.h"
     28 #include "tensorflow/core/lib/core/stringpiece.h"
     29 #include "tensorflow/core/lib/gtl/array_slice.h"
     30 #include "tensorflow/core/lib/gtl/flatmap.h"
     31 #include "tensorflow/core/platform/macros.h"
     32 #include "tensorflow/core/platform/types.h"
     33 
     34 namespace xla {
     35 
     36 class HloComputation;
     37 class HloInstruction;
     38 
     39 // A postorder depth-first HloInstruction visitor. When Handle* is called on an
     40 // instruction, all its operands were already visited. User code can subclass
     41 // this to iterate over an HloInstruction DAG. The Handle* routines have
     42 // operands / data unpacked for ease of use in the visitor subclass.
     43 //
     44 // No instruction will ever be visited twice; however, the root instruction will
     45 // be reported again when the traversal is done via a call to FinishVisit.
     46 //
     47 // A subclass must override at least
     48 // (either HandleElementwiseUnary or all the Handle methods for unary ops) and
     49 // (either HandleElementwiseBinary or all the Handle methods for binary ops)).
     50 // The default Handle methods for (unary, binary) ops call
     51 // (HandleElementwiseUnary, HandleElementwiseBinary).
     52 // The default (HandleElementwiseUnary, HandleElementwiseBinary) return an
     53 // "unimplemented" error status.
     54 //
     55 // Note: this may change to an iterator in the future for flexibility purposes.
     56 //
     57 // Users should not use this class directly, but use the type-aliases
     58 // DfsHloVisitor/ConstDfsHloVisitor instead.
     59 template <typename HloInstructionPtr>
     60 class DfsHloVisitorBase {
     61   static_assert(
     62       std::is_same<HloInstruction*, HloInstructionPtr>::value ||
     63           std::is_same<const HloInstruction*, HloInstructionPtr>::value,
     64       "Template argument expected to be HloInstruction* or const "
     65       "HloInstruction*");
     66 
     67  public:
     68   DfsHloVisitorBase() {}
     69   virtual ~DfsHloVisitorBase() {}
     70 
     71   // These routines are self-descriptive, see class comment for usage
     72   // information.
     73 
     74   virtual Status HandleElementwiseUnary(HloInstructionPtr hlo);
     75   virtual Status HandleElementwiseBinary(HloInstructionPtr hlo);
     76 
     77   virtual Status HandleClamp(HloInstructionPtr hlo) = 0;
     78   virtual Status HandleSelect(HloInstructionPtr hlo) = 0;
     79   virtual Status HandleMaximum(HloInstructionPtr hlo) {
     80     return HandleElementwiseBinary(hlo);
     81   }
     82   virtual Status HandleMinimum(HloInstructionPtr hlo) {
     83     return HandleElementwiseBinary(hlo);
     84   }
     85   virtual Status HandleConcatenate(HloInstructionPtr hlo) = 0;
     86   virtual Status HandleConvert(HloInstructionPtr hlo) {
     87     return HandleElementwiseUnary(hlo);
     88   }
     89   virtual Status HandleBitcastConvert(HloInstructionPtr hlo) {
     90     return HandleElementwiseUnary(hlo);
     91   }
     92   virtual Status HandleCopy(HloInstructionPtr hlo) {
     93     return HandleElementwiseUnary(hlo);
     94   }
     95   virtual Status HandleComplex(HloInstructionPtr hlo) {
     96     return HandleElementwiseBinary(hlo);
     97   }
     98   virtual Status HandleMultiply(HloInstructionPtr hlo) {
     99     return HandleElementwiseBinary(hlo);
    100   }
    101   virtual Status HandleDot(HloInstructionPtr hlo) = 0;
    102   virtual Status HandlePower(HloInstructionPtr hlo) {
    103     return HandleElementwiseBinary(hlo);
    104   }
    105   virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
    106   virtual Status HandleFft(HloInstructionPtr fft) = 0;
    107   virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0;
    108   virtual Status HandleCompare(HloInstructionPtr hlo) {
    109     return HandleElementwiseBinary(hlo);
    110   }
    111   virtual Status HandleAdd(HloInstructionPtr hlo) {
    112     return HandleElementwiseBinary(hlo);
    113   }
    114   virtual Status HandleDivide(HloInstructionPtr hlo) {
    115     return HandleElementwiseBinary(hlo);
    116   }
    117   virtual Status HandleRemainder(HloInstructionPtr hlo) {
    118     return HandleElementwiseBinary(hlo);
    119   }
    120   virtual Status HandleSubtract(HloInstructionPtr hlo) {
    121     return HandleElementwiseBinary(hlo);
    122   }
    123   virtual Status HandleAbs(HloInstructionPtr hlo) {
    124     return HandleElementwiseUnary(hlo);
    125   }
    126   virtual Status HandleAtan2(HloInstructionPtr hlo) {
    127     return HandleElementwiseBinary(hlo);
    128   }
    129   virtual Status HandleRound(HloInstructionPtr hlo) {
    130     return HandleElementwiseUnary(hlo);
    131   }
    132   virtual Status HandleSign(HloInstructionPtr hlo) {
    133     return HandleElementwiseUnary(hlo);
    134   }
    135   virtual Status HandleNegate(HloInstructionPtr hlo) {
    136     return HandleElementwiseUnary(hlo);
    137   }
    138   virtual Status HandleExp(HloInstructionPtr hlo) {
    139     return HandleElementwiseUnary(hlo);
    140   }
    141   virtual Status HandleFloor(HloInstructionPtr hlo) {
    142     return HandleElementwiseUnary(hlo);
    143   }
    144   virtual Status HandleCeil(HloInstructionPtr hlo) {
    145     return HandleElementwiseUnary(hlo);
    146   }
    147   virtual Status HandleLog(HloInstructionPtr hlo) {
    148     return HandleElementwiseUnary(hlo);
    149   }
    150   virtual Status HandleCos(HloInstructionPtr hlo) {
    151     return HandleElementwiseUnary(hlo);
    152   }
    153   virtual Status HandleSin(HloInstructionPtr hlo) {
    154     return HandleElementwiseUnary(hlo);
    155   }
    156   virtual Status HandleTanh(HloInstructionPtr hlo) {
    157     return HandleElementwiseUnary(hlo);
    158   }
    159   virtual Status HandleReal(HloInstructionPtr hlo) {
    160     return HandleElementwiseUnary(hlo);
    161   }
    162   virtual Status HandleImag(HloInstructionPtr hlo) {
    163     return HandleElementwiseUnary(hlo);
    164   }
    165   virtual Status HandleIsFinite(HloInstructionPtr hlo) {
    166     return HandleElementwiseUnary(hlo);
    167   }
    168   virtual Status HandleAnd(HloInstructionPtr hlo) {
    169     return HandleElementwiseBinary(hlo);
    170   }
    171   virtual Status HandleNot(HloInstructionPtr hlo) {
    172     return HandleElementwiseUnary(hlo);
    173   }
    174   virtual Status HandleOr(HloInstructionPtr hlo) {
    175     return HandleElementwiseBinary(hlo);
    176   }
    177   virtual Status HandleShiftLeft(HloInstructionPtr hlo) {
    178     return HandleElementwiseBinary(hlo);
    179   }
    180   virtual Status HandleShiftRightArithmetic(HloInstructionPtr hlo) {
    181     return HandleElementwiseBinary(hlo);
    182   }
    183   virtual Status HandleShiftRightLogical(HloInstructionPtr hlo) {
    184     return HandleElementwiseBinary(hlo);
    185   }
    186 
    187   virtual Status HandleReducePrecision(HloInstructionPtr hlo) {
    188     return HandleElementwiseUnary(hlo);
    189   }
    190 
    191   virtual Status HandleInfeed(HloInstructionPtr hlo) = 0;
    192   virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0;
    193   virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0;
    194   virtual Status HandleRng(HloInstructionPtr hlo) = 0;
    195   virtual Status HandleReverse(HloInstructionPtr hlo) = 0;
    196   virtual Status HandleSort(HloInstructionPtr hlo) = 0;
    197   virtual Status HandleConstant(HloInstructionPtr hlo) = 0;
    198   virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0;
    199   virtual Status HandleReduce(HloInstructionPtr hlo) = 0;
    200   virtual Status HandleBitcast(HloInstructionPtr hlo) = 0;
    201   virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0;
    202   virtual Status HandleReshape(HloInstructionPtr hlo) = 0;
    203   virtual Status HandleTranspose(HloInstructionPtr hlo) = 0;
    204   virtual Status HandleParameter(HloInstructionPtr hlo) = 0;
    205   virtual Status HandleFusion(HloInstructionPtr hlo) = 0;
    206   virtual Status HandleCall(HloInstructionPtr hlo) = 0;
    207   virtual Status HandleCustomCall(HloInstructionPtr hlo) = 0;
    208   virtual Status HandleSlice(HloInstructionPtr hlo) = 0;
    209   virtual Status HandleDynamicSlice(HloInstructionPtr hlo) = 0;
    210   virtual Status HandleDynamicUpdateSlice(HloInstructionPtr hlo) = 0;
    211   virtual Status HandleTuple(HloInstructionPtr hlo) = 0;
    212   virtual Status HandleMap(HloInstructionPtr hlo) = 0;
    213   virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0;
    214   virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0;
    215   virtual Status HandleWhile(HloInstructionPtr hlo) = 0;
    216   virtual Status HandleConditional(HloInstructionPtr hlo) = 0;
    217   virtual Status HandleGather(HloInstructionPtr hlo) = 0;
    218 
    219   virtual Status HandlePad(HloInstructionPtr hlo) = 0;
    220 
    221   virtual Status HandleSend(HloInstructionPtr send) = 0;
    222   virtual Status HandleSendDone(HloInstructionPtr send_done) = 0;
    223 
    224   virtual Status HandleRecv(HloInstructionPtr recv) = 0;
    225   virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0;
    226 
    227   virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0;
    228 
    229   virtual Status HandleBatchNormInference(HloInstructionPtr hlo) = 0;
    230 
    231   virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0;
    232 
    233   // Invoked to inform the visitor that the traversal has completed, and that
    234   // the root was "root".
    235   virtual Status FinishVisit(HloInstructionPtr root) = 0;
    236 
    237   // 3 possible visitation states of HLO instructions. Each instruction's
    238   // state only flows one way: kNotVisited -> kVisiting -> kVisited.
    239   enum VisitState {
    240     kNotVisited = 0,
    241     kVisiting = 1,
    242     kVisited = 2,
    243   };
    244 
    245   VisitState GetVisitState(int id) { return visit_state_.GetState(id); }
    246   VisitState GetVisitState(const HloInstruction& instruction);
    247 
    248   // Resize internal state if necessary to hold state for ids <= num.
    249   // This call is purely a performance hint and can be omitted without
    250   // affecting correctness.
    251   void ReserveVisitStates(int num) { visit_state_.Reserve(num); }
    252 
    253   // Useful when we want to visit the same computation more than once with the
    254   // same visitor.
    255   void ResetVisitStates() { visit_state_.Reset(); }
    256 
    257   void SetVisitState(int id, VisitState state) {
    258     visit_state_.SetState(id, state);
    259   }
    260 
    261   // Sets the visitation state of the given instruction as kVisiting.
    262   //
    263   // Precondition: current state must be kNotVisited.
    264   void SetVisiting(const HloInstruction& instruction);
    265 
    266   // Sets the visitation state of the given instruction as kVisited.
    267   //
    268   // Precondition: current state must be either kNotVisited or kVisiting.
    269   void SetVisited(const HloInstruction& instruction);
    270 
    271   // Returns whether the state of the given instruction is kVisiting.
    272   bool IsVisiting(const HloInstruction& instruction) {
    273     return GetVisitState(instruction) == kVisiting;
    274   }
    275 
    276   // Returns whether the state of the given instruction is kVisited.
    277   bool DidVisit(const HloInstruction& instruction) {
    278     return GetVisitState(instruction) == kVisited;
    279   }
    280 
    281   // Returns whether the state of the given instruction is kNotVisited.
    282   bool NotVisited(const HloInstruction& instruction) {
    283     return GetVisitState(instruction) == kNotVisited;
    284   }
    285 
    286   // This method should be overridden by subclasses that wish to run some
    287   // operation on an op before its Handle* visitor method is called.
    288   //
    289   // For any HLO op, the order of calls is:
    290   //
    291   //   Preprocess(op);
    292   //   Handle/OpType/(op);
    293   //   Postprocess(op);
    294   //
    295   // Overriding methods should call DfsHloVisitor::Preprocess before doing their
    296   // own preprocessing.
    297   virtual Status Preprocess(HloInstructionPtr hlo);
    298 
    299   // This method should be overridden by subclasses that wish to run some
    300   // operation on an op after its Handle* visitor method is called. See
    301   // Preprocess for more details.
    302   //
    303   // Overriding methods should call DfsHloVisitor::Postprocess after doing their
    304   // own postprocessing.
    305   virtual Status Postprocess(HloInstructionPtr hlo);
    306 
    307  private:
    308   class DFSVisitStates {
    309    public:
    310     DFSVisitStates() {}
    311     void Reserve(uint64 num) {
    312       states_.reserve((num + kStatesPerWord - 1) / kStatesPerWord);
    313     }
    314     VisitState GetState(uint64 id) {
    315       uint64 word_index = id / kStatesPerWord;
    316       if (word_index >= states_.size()) {
    317         return VisitState::kNotVisited;
    318       }
    319       static_assert(static_cast<int>(VisitState::kVisited) < 3,
    320                     "VisitState must fit in two bits");
    321       uint64 w = states_[word_index];
    322       uint32 shift = 2 * (id % kStatesPerWord);  // 2 bits per state
    323       return static_cast<VisitState>((w >> shift) & 0x3);
    324     }
    325     void SetState(uint64 id, VisitState state) {
    326       uint64 word_index = id / kStatesPerWord;
    327       if (word_index >= states_.size()) {
    328         states_.resize(word_index + 1, 0);
    329       }
    330       uint64* w = &states_[word_index];
    331       uint32 shift = 2 * (id % kStatesPerWord);  // 2 bits per state
    332       uint64 mask = 0x3ull << shift;
    333       *w = (*w & ~mask) | (static_cast<uint64>(state) << shift);
    334       DCHECK_EQ(GetState(id), state);
    335     }
    336     void Reset() { states_.clear(); }
    337 
    338    private:
    339     static const uint32 kStatesPerWord = sizeof(uint64) / 2 /*bits per entry*/;
    340     // Map from id to two-bit states.  We store 32 such states per 64-bit
    341     // value
    342     std::vector<uint64> states_;
    343   };
    344 
    345   DFSVisitStates visit_state_;
    346 
    347   TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorBase);
    348 };
    349 
    350 // Users should use one of these two type aliases, which are the only two valid
    351 // instantiations of DfsHloVisitorBase.
    352 using DfsHloVisitor = DfsHloVisitorBase<HloInstruction*>;
    353 using ConstDfsHloVisitor = DfsHloVisitorBase<const HloInstruction*>;
    354 
    355 }  // namespace xla
    356 
    357 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
    358