Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
     18 
     19 #include <type_traits>
     20 #include <vector>
     21 
     22 #include "absl/container/flat_hash_map.h"
     23 #include "absl/strings/string_view.h"
     24 #include "absl/types/span.h"
     25 #include "tensorflow/compiler/xla/literal.h"
     26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     27 #include "tensorflow/compiler/xla/status.h"
     28 #include "tensorflow/compiler/xla/types.h"
     29 #include "tensorflow/compiler/xla/xla_data.pb.h"
     30 #include "tensorflow/core/lib/core/status.h"
     31 #include "tensorflow/core/platform/macros.h"
     32 #include "tensorflow/core/platform/types.h"
     33 
     34 namespace xla {
     35 
     36 class HloComputation;
     37 class HloInstruction;
     38 
     39 // A postorder depth-first HloInstruction visitor. When Handle* is called on an
     40 // instruction, all its operands were already visited. User code can subclass
     41 // this to iterate over an HloInstruction DAG. The Handle* routines have
     42 // operands / data unpacked for ease of use in the visitor subclass.
     43 //
     44 // No instruction will ever be visited twice; however, the root instruction will
     45 // be reported again when the traversal is done via a call to FinishVisit.
     46 //
     47 // A subclass must override at least
     48 // (either HandleElementwiseUnary or all the Handle methods for unary ops) and
     49 // (either HandleElementwiseBinary or all the Handle methods for binary ops)).
     50 // The default Handle methods for (unary, binary) ops call
     51 // (HandleElementwiseUnary, HandleElementwiseBinary).
     52 // The default (HandleElementwiseUnary, HandleElementwiseBinary) return an
     53 // "unimplemented" error status.
     54 //
     55 // Note: this may change to an iterator in the future for flexibility purposes.
     56 //
     57 // Users should not use this class directly, but use the type-aliases
     58 // DfsHloVisitor/ConstDfsHloVisitor instead.
     59 template <typename HloInstructionPtr>
     60 class DfsHloVisitorBase {
     61   static_assert(
     62       std::is_same<HloInstruction*, HloInstructionPtr>::value ||
     63           std::is_same<const HloInstruction*, HloInstructionPtr>::value,
     64       "Template argument expected to be HloInstruction* or const "
     65       "HloInstruction*");
     66 
     67  public:
     68   DfsHloVisitorBase() {}
     69   virtual ~DfsHloVisitorBase() {}
     70 
     71   // These routines are self-descriptive, see class comment for usage
     72   // information.
     73 
     74   virtual Status HandleElementwiseUnary(HloInstructionPtr hlo);
     75   virtual Status HandleElementwiseBinary(HloInstructionPtr hlo);
     76 
     77   virtual Status HandleClamp(HloInstructionPtr hlo) = 0;
     78   virtual Status HandleSelect(HloInstructionPtr hlo) = 0;
     79   virtual Status HandleTupleSelect(HloInstructionPtr hlo) = 0;
     80   virtual Status HandleMaximum(HloInstructionPtr hlo) {
     81     return HandleElementwiseBinary(hlo);
     82   }
     83   virtual Status HandleMinimum(HloInstructionPtr hlo) {
     84     return HandleElementwiseBinary(hlo);
     85   }
     86   virtual Status HandleConcatenate(HloInstructionPtr hlo) = 0;
     87   virtual Status HandleConvert(HloInstructionPtr hlo) {
     88     return HandleElementwiseUnary(hlo);
     89   }
     90   virtual Status HandleBitcastConvert(HloInstructionPtr hlo) {
     91     return HandleElementwiseUnary(hlo);
     92   }
     93   virtual Status HandleCopy(HloInstructionPtr hlo) {
     94     return HandleElementwiseUnary(hlo);
     95   }
     96   virtual Status HandleComplex(HloInstructionPtr hlo) {
     97     return HandleElementwiseBinary(hlo);
     98   }
     99   virtual Status HandleMultiply(HloInstructionPtr hlo) {
    100     return HandleElementwiseBinary(hlo);
    101   }
    102   virtual Status HandleDot(HloInstructionPtr hlo) = 0;
    103   virtual Status HandlePower(HloInstructionPtr hlo) {
    104     return HandleElementwiseBinary(hlo);
    105   }
    106   virtual Status HandleSqrt(HloInstructionPtr hlo) {
    107     return HandleElementwiseUnary(hlo);
    108   }
    109   virtual Status HandleRsqrt(HloInstructionPtr hlo) {
    110     return HandleElementwiseUnary(hlo);
    111   }
    112   virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
    113   virtual Status HandleFft(HloInstructionPtr fft) = 0;
    114   virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0;
    115   virtual Status HandleCholesky(HloInstructionPtr hlo) = 0;
    116   virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0;
    117   virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
    118   virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0;
    119   virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0;
    120   virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0;
    121   virtual Status HandleCompare(HloInstructionPtr hlo) {
    122     return HandleElementwiseBinary(hlo);
    123   }
    124   virtual Status HandleAdd(HloInstructionPtr hlo) {
    125     return HandleElementwiseBinary(hlo);
    126   }
    127   virtual Status HandleDivide(HloInstructionPtr hlo) {
    128     return HandleElementwiseBinary(hlo);
    129   }
    130   virtual Status HandleRemainder(HloInstructionPtr hlo) {
    131     return HandleElementwiseBinary(hlo);
    132   }
    133   virtual Status HandleSubtract(HloInstructionPtr hlo) {
    134     return HandleElementwiseBinary(hlo);
    135   }
    136   virtual Status HandleAbs(HloInstructionPtr hlo) {
    137     return HandleElementwiseUnary(hlo);
    138   }
    139   virtual Status HandleAtan2(HloInstructionPtr hlo) {
    140     return HandleElementwiseBinary(hlo);
    141   }
    142   virtual Status HandleRound(HloInstructionPtr hlo) {
    143     return HandleElementwiseUnary(hlo);
    144   }
    145   virtual Status HandleSign(HloInstructionPtr hlo) {
    146     return HandleElementwiseUnary(hlo);
    147   }
    148   virtual Status HandleNegate(HloInstructionPtr hlo) {
    149     return HandleElementwiseUnary(hlo);
    150   }
    151   virtual Status HandleExp(HloInstructionPtr hlo) {
    152     return HandleElementwiseUnary(hlo);
    153   }
    154   virtual Status HandleExpm1(HloInstructionPtr hlo) {
    155     return HandleElementwiseUnary(hlo);
    156   }
    157   virtual Status HandleFloor(HloInstructionPtr hlo) {
    158     return HandleElementwiseUnary(hlo);
    159   }
    160   virtual Status HandleCeil(HloInstructionPtr hlo) {
    161     return HandleElementwiseUnary(hlo);
    162   }
    163   virtual Status HandleLog(HloInstructionPtr hlo) {
    164     return HandleElementwiseUnary(hlo);
    165   }
    166   virtual Status HandleClz(HloInstructionPtr hlo) {
    167     return HandleElementwiseUnary(hlo);
    168   }
    169   virtual Status HandleLog1p(HloInstructionPtr hlo) {
    170     return HandleElementwiseUnary(hlo);
    171   }
    172   virtual Status HandleCos(HloInstructionPtr hlo) {
    173     return HandleElementwiseUnary(hlo);
    174   }
    175   virtual Status HandleSin(HloInstructionPtr hlo) {
    176     return HandleElementwiseUnary(hlo);
    177   }
    178   virtual Status HandleTanh(HloInstructionPtr hlo) {
    179     return HandleElementwiseUnary(hlo);
    180   }
    181   virtual Status HandleReal(HloInstructionPtr hlo) {
    182     return HandleElementwiseUnary(hlo);
    183   }
    184   virtual Status HandleImag(HloInstructionPtr hlo) {
    185     return HandleElementwiseUnary(hlo);
    186   }
    187   virtual Status HandleIsFinite(HloInstructionPtr hlo) {
    188     return HandleElementwiseUnary(hlo);
    189   }
    190   virtual Status HandleAnd(HloInstructionPtr hlo) {
    191     return HandleElementwiseBinary(hlo);
    192   }
    193   virtual Status HandleNot(HloInstructionPtr hlo) {
    194     return HandleElementwiseUnary(hlo);
    195   }
    196   virtual Status HandleOr(HloInstructionPtr hlo) {
    197     return HandleElementwiseBinary(hlo);
    198   }
    199   virtual Status HandleXor(HloInstructionPtr hlo) {
    200     return HandleElementwiseBinary(hlo);
    201   }
    202   virtual Status HandleShiftLeft(HloInstructionPtr hlo) {
    203     return HandleElementwiseBinary(hlo);
    204   }
    205   virtual Status HandleShiftRightArithmetic(HloInstructionPtr hlo) {
    206     return HandleElementwiseBinary(hlo);
    207   }
    208   virtual Status HandleShiftRightLogical(HloInstructionPtr hlo) {
    209     return HandleElementwiseBinary(hlo);
    210   }
    211 
    212   virtual Status HandleReducePrecision(HloInstructionPtr hlo) {
    213     return HandleElementwiseUnary(hlo);
    214   }
    215 
    216   virtual Status HandleDomain(HloInstructionPtr hlo) {
    217     return HandleElementwiseUnary(hlo);
    218   }
    219 
    220   virtual Status HandleInfeed(HloInstructionPtr hlo) = 0;
    221   virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0;
    222   virtual Status HandleRng(HloInstructionPtr hlo) = 0;
    223   virtual Status HandleReverse(HloInstructionPtr hlo) = 0;
    224   virtual Status HandleSort(HloInstructionPtr hlo) = 0;
    225   virtual Status HandleConstant(HloInstructionPtr hlo) = 0;
    226   virtual Status HandleIota(HloInstructionPtr hlo) = 0;
    227   virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0;
    228   virtual Status HandleReduce(HloInstructionPtr hlo) = 0;
    229   virtual Status HandleBitcast(HloInstructionPtr hlo) = 0;
    230   virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0;
    231   virtual Status HandleReshape(HloInstructionPtr hlo) = 0;
    232   virtual Status HandleTranspose(HloInstructionPtr hlo) = 0;
    233   virtual Status HandleParameter(HloInstructionPtr hlo) = 0;
    234   virtual Status HandleFusion(HloInstructionPtr hlo) = 0;
    235   virtual Status HandleCall(HloInstructionPtr hlo) = 0;
    236   virtual Status HandleCustomCall(HloInstructionPtr hlo) = 0;
    237   virtual Status HandleSlice(HloInstructionPtr hlo) = 0;
    238   virtual Status HandleDynamicSlice(HloInstructionPtr hlo) = 0;
    239   virtual Status HandleDynamicUpdateSlice(HloInstructionPtr hlo) = 0;
    240   virtual Status HandleTuple(HloInstructionPtr hlo) = 0;
    241   virtual Status HandleMap(HloInstructionPtr hlo) = 0;
    242   virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0;
    243   virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0;
    244   virtual Status HandleWhile(HloInstructionPtr hlo) = 0;
    245   virtual Status HandleConditional(HloInstructionPtr hlo) = 0;
    246   virtual Status HandleGather(HloInstructionPtr hlo) = 0;
    247   virtual Status HandleScatter(HloInstructionPtr hlo) = 0;
    248 
    249   virtual Status HandlePad(HloInstructionPtr hlo) = 0;
    250 
    251   virtual Status HandleSend(HloInstructionPtr send) = 0;
    252   virtual Status HandleSendDone(HloInstructionPtr send_done) = 0;
    253 
    254   virtual Status HandleRecv(HloInstructionPtr recv) = 0;
    255   virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0;
    256 
    257   virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0;
    258 
    259   virtual Status HandleBatchNormInference(HloInstructionPtr hlo) = 0;
    260 
    261   virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0;
    262 
    263   virtual Status HandleAddDependency(HloInstructionPtr add_dependency) = 0;
    264   virtual Status HandleAfterAll(HloInstructionPtr token) = 0;
    265 
    266   // Invoked to inform the visitor that the traversal has completed, and that
    267   // the root was "root".
    268   virtual Status FinishVisit(HloInstructionPtr root) = 0;
    269 
    270   // 3 possible visitation states of HLO instructions. Each instruction's
    271   // state only flows one way: kNotVisited -> kVisiting -> kVisited.
    272   enum VisitState {
    273     kNotVisited = 0,
    274     kVisiting = 1,
    275     kVisited = 2,
    276   };
    277 
    278   VisitState GetVisitState(int id) {
    279     auto iter = visit_state_.find(id);
    280     if (iter == visit_state_.end()) {
    281       return VisitState::kNotVisited;
    282     }
    283     return iter->second;
    284   }
    285   VisitState GetVisitState(const HloInstruction& instruction);
    286 
    287   // Resize internal state if necessary to hold state for ids <= num.
    288   // This call is purely a performance hint and can be omitted without
    289   // affecting correctness.
    290   void ReserveVisitStates(int num) { visit_state_.reserve(num); }
    291 
    292   // Useful when we want to visit the same computation more than once with the
    293   // same visitor.
    294   void ResetVisitStates() { visit_state_.clear(); }
    295 
    296   void SetVisitState(int id, VisitState state) { visit_state_[id] = state; }
    297 
    298   // Sets the visitation state of the given instruction as kVisiting.
    299   //
    300   // Precondition: current state must be kNotVisited.
    301   void SetVisiting(const HloInstruction& instruction);
    302 
    303   // Sets the visitation state of the given instruction as kVisited.
    304   //
    305   // Precondition: current state must be either kNotVisited or kVisiting.
    306   void SetVisited(const HloInstruction& instruction);
    307 
    308   // Returns whether the state of the given instruction is kVisiting.
    309   bool IsVisiting(const HloInstruction& instruction) {
    310     return GetVisitState(instruction) == kVisiting;
    311   }
    312 
    313   // Returns whether the state of the given instruction is kVisited.
    314   bool DidVisit(const HloInstruction& instruction) {
    315     return GetVisitState(instruction) == kVisited;
    316   }
    317 
    318   // Returns whether the state of the given instruction is kNotVisited.
    319   bool NotVisited(const HloInstruction& instruction) {
    320     return GetVisitState(instruction) == kNotVisited;
    321   }
    322 
    323   // This method should be overridden by subclasses that wish to run some
    324   // operation on an op before its Handle* visitor method is called.
    325   //
    326   // For any HLO op, the order of calls is:
    327   //
    328   //   Preprocess(op);
    329   //   Handle/OpType/(op);
    330   //   Postprocess(op);
    331   //
    332   // Overriding methods should call DfsHloVisitor::Preprocess before doing their
    333   // own preprocessing.
    334   virtual Status Preprocess(HloInstructionPtr hlo);
    335 
    336   // This method should be overridden by subclasses that wish to run some
    337   // operation on an op after its Handle* visitor method is called. See
    338   // Preprocess for more details.
    339   //
    340   // Overriding methods should call DfsHloVisitor::Postprocess after doing their
    341   // own postprocessing.
    342   virtual Status Postprocess(HloInstructionPtr hlo);
    343 
    344  private:
    345   absl::flat_hash_map<int, VisitState> visit_state_;
    346 
    347   TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorBase);
    348 };
    349 
    350 // Users should use one of these two type aliases, which are the only two valid
    351 // instantiations of DfsHloVisitorBase.
    352 using DfsHloVisitor = DfsHloVisitorBase<HloInstruction*>;
    353 using ConstDfsHloVisitor = DfsHloVisitorBase<const HloInstruction*>;
    354 
    355 }  // namespace xla
    356 
    357 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
    358