Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
     18 
#include <functional>
#include <utility>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
     29 
     30 namespace xla {
     31 
     32 class HloComputation;
     33 class HloInstruction;
     34 
     35 // DfsHloVisitor with default action based on the HloInstruction being visited.
     36 // Users should not use this class directly, but use the type aliases
     37 // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead.
     38 //
     39 // Do *not* add an override to this class if the opcode is covered by
     40 // HandleElementwiseUnary/Binary. These opcode handlers dispatch to
     41 // HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler
     42 // here will break passes which rely on the HandleElementwiseUnary/Binary
     43 // handling these opcodes.
     44 template <typename HloInstructionPtr>
     45 class DfsHloVisitorWithDefaultBase
     46     : public DfsHloVisitorBase<HloInstructionPtr> {
     47  public:
     48   DfsHloVisitorWithDefaultBase() {}
     49   ~DfsHloVisitorWithDefaultBase() override {}
     50 
     51   // Default action performed on HloInstruction.
     52   virtual Status DefaultAction(HloInstructionPtr hlo_instruction) = 0;
     53 
     54   Status HandleElementwiseUnary(HloInstructionPtr hlo) override {
     55     return DefaultAction(hlo);
     56   }
     57   Status HandleElementwiseBinary(HloInstructionPtr hlo) override {
     58     return DefaultAction(hlo);
     59   }
     60 
     61   Status HandleBatchNormTraining(HloInstructionPtr hlo) override {
     62     return DefaultAction(hlo);
     63   }
     64 
     65   Status HandleBatchNormInference(HloInstructionPtr hlo) override {
     66     return DefaultAction(hlo);
     67   }
     68 
     69   Status HandleBatchNormGrad(HloInstructionPtr hlo) override {
     70     return DefaultAction(hlo);
     71   }
     72 
     73   Status HandleClamp(HloInstructionPtr clamp) override {
     74     return DefaultAction(clamp);
     75   }
     76   Status HandleConcatenate(HloInstructionPtr concatenate) override {
     77     return DefaultAction(concatenate);
     78   }
     79   Status HandleSelect(HloInstructionPtr select) override {
     80     return DefaultAction(select);
     81   }
     82   Status HandleTupleSelect(HloInstructionPtr tuple_select) override {
     83     return DefaultAction(tuple_select);
     84   }
     85   Status HandleDot(HloInstructionPtr dot) override {
     86     return DefaultAction(dot);
     87   }
     88   Status HandleConvolution(HloInstructionPtr convolution) override {
     89     return DefaultAction(convolution);
     90   }
     91   Status HandleFft(HloInstructionPtr fft) override {
     92     return DefaultAction(fft);
     93   }
     94   Status HandleTriangularSolve(HloInstructionPtr hlo) override {
     95     return DefaultAction(hlo);
     96   }
     97   Status HandleCholesky(HloInstructionPtr hlo) override {
     98     return DefaultAction(hlo);
     99   }
    100   Status HandleAllReduce(HloInstructionPtr crs) override {
    101     return DefaultAction(crs);
    102   }
    103   Status HandleAllToAll(HloInstructionPtr hlo) override {
    104     return DefaultAction(hlo);
    105   }
    106   Status HandleCollectivePermute(HloInstructionPtr hlo) override {
    107     return DefaultAction(hlo);
    108   }
    109   Status HandleReplicaId(HloInstructionPtr hlo) override {
    110     return DefaultAction(hlo);
    111   }
    112   Status HandleRng(HloInstructionPtr random) override {
    113     return DefaultAction(random);
    114   }
    115   Status HandleInfeed(HloInstructionPtr infeed) override {
    116     return DefaultAction(infeed);
    117   }
    118   Status HandleOutfeed(HloInstructionPtr outfeed) override {
    119     return DefaultAction(outfeed);
    120   }
    121   Status HandleReverse(HloInstructionPtr reverse) override {
    122     return DefaultAction(reverse);
    123   }
    124   Status HandleSort(HloInstructionPtr sort) override {
    125     return DefaultAction(sort);
    126   }
    127   Status HandleConstant(HloInstructionPtr constant) override {
    128     return DefaultAction(constant);
    129   }
    130   Status HandleIota(HloInstructionPtr iota) override {
    131     return DefaultAction(iota);
    132   }
    133   Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override {
    134     return DefaultAction(get_tuple_element);
    135   }
    136   Status HandleParameter(HloInstructionPtr parameter) override {
    137     return DefaultAction(parameter);
    138   }
    139   Status HandleFusion(HloInstructionPtr fusion) override {
    140     return DefaultAction(fusion);
    141   }
    142   Status HandleCall(HloInstructionPtr call) override {
    143     return DefaultAction(call);
    144   }
    145   Status HandleCustomCall(HloInstructionPtr custom_call) override {
    146     return DefaultAction(custom_call);
    147   }
    148   Status HandleSlice(HloInstructionPtr slice) override {
    149     return DefaultAction(slice);
    150   }
    151   Status HandleDynamicSlice(HloInstructionPtr dynamic_slice) override {
    152     return DefaultAction(dynamic_slice);
    153   }
    154   Status HandleDynamicUpdateSlice(
    155       HloInstructionPtr dynamic_update_slice) override {
    156     return DefaultAction(dynamic_update_slice);
    157   }
    158   Status HandleTuple(HloInstructionPtr tuple) override {
    159     return DefaultAction(tuple);
    160   }
    161   Status HandleMap(HloInstructionPtr map) override {
    162     return DefaultAction(map);
    163   }
    164   Status HandleReduce(HloInstructionPtr reduce) override {
    165     return DefaultAction(reduce);
    166   }
    167   Status HandleReduceWindow(HloInstructionPtr reduce_window) override {
    168     return DefaultAction(reduce_window);
    169   }
    170   Status HandleSelectAndScatter(HloInstructionPtr select_and_scatter) override {
    171     return DefaultAction(select_and_scatter);
    172   }
    173   Status HandleBitcast(HloInstructionPtr bitcast) override {
    174     return DefaultAction(bitcast);
    175   }
    176   Status HandleBroadcast(HloInstructionPtr broadcast) override {
    177     return DefaultAction(broadcast);
    178   }
    179   Status HandlePad(HloInstructionPtr pad) override {
    180     return DefaultAction(pad);
    181   }
    182   Status HandleReshape(HloInstructionPtr reshape) override {
    183     return DefaultAction(reshape);
    184   }
    185   Status HandleTranspose(HloInstructionPtr transpose) override {
    186     return DefaultAction(transpose);
    187   }
    188   Status HandleWhile(HloInstructionPtr xla_while) override {
    189     return DefaultAction(xla_while);
    190   }
    191   Status HandleConditional(HloInstructionPtr conditional) override {
    192     return DefaultAction(conditional);
    193   }
    194   Status HandleRecv(HloInstructionPtr recv) override {
    195     return DefaultAction(recv);
    196   }
    197   Status HandleRecvDone(HloInstructionPtr recv_done) override {
    198     return DefaultAction(recv_done);
    199   }
    200   Status HandleSend(HloInstructionPtr send) override {
    201     return DefaultAction(send);
    202   }
    203   Status HandleSendDone(HloInstructionPtr send_done) override {
    204     return DefaultAction(send_done);
    205   }
    206   Status HandleGather(HloInstructionPtr gather) override {
    207     return DefaultAction(gather);
    208   }
    209   Status HandleScatter(HloInstructionPtr scatter) override {
    210     return DefaultAction(scatter);
    211   }
    212   Status HandleAfterAll(HloInstructionPtr token) override {
    213     return DefaultAction(token);
    214   }
    215   Status HandleGetDimensionSize(HloInstructionPtr get_size) override {
    216     return DefaultAction(get_size);
    217   }
    218   Status HandleAddDependency(HloInstructionPtr add_dependency) override {
    219     return DefaultAction(add_dependency);
    220   }
    221 
    222   // Invoked to inform the visitor that the traversal has completed, and that
    223   // the root was "root".
    224   Status FinishVisit(HloInstructionPtr /*root*/) override {
    225     return Status::OK();
    226   }
    227 
    228  private:
    229   TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefaultBase);
    230 };
    231 
    232 // Users should use these type aliases which are only two valid instantiations.
    233 using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<HloInstruction*>;
    234 using ConstDfsHloVisitorWithDefault =
    235     DfsHloVisitorWithDefaultBase<const HloInstruction*>;
    236 
    237 // (Const)FunctionVisitor lets you transform an
    238 // std::function<Status((const) HloInstruction*)> into a (Const)DfsHloVisitor.
    239 //
    240 // This is useful if you have code that needs to handle visitors in the form of
    241 // both std::function and DfsHloVisitor.  You can wrap the function in a
    242 // FunctionVisitor and then treat it like any other DfsHloVisitor.
    243 template <typename HloInstructionPtr>
    244 class FunctionVisitorBase
    245     : public DfsHloVisitorWithDefaultBase<HloInstructionPtr> {
    246  public:
    247   explicit FunctionVisitorBase(
    248       std::function<Status(HloInstructionPtr)> visitor_func)
    249       : visitor_func_(std::move(visitor_func)) {}
    250 
    251   Status DefaultAction(HloInstructionPtr hlo_instruction) override {
    252     return visitor_func_(hlo_instruction);
    253   }
    254 
    255  private:
    256   TF_DISALLOW_COPY_AND_ASSIGN(FunctionVisitorBase);
    257 
    258   std::function<Status(HloInstructionPtr)> visitor_func_;
    259 };
    260 
    261 using FunctionVisitor = FunctionVisitorBase<HloInstruction*>;
    262 using ConstFunctionVisitor = FunctionVisitorBase<const HloInstruction*>;
    263 
    264 }  // namespace xla
    265 
    266 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
    267