Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
     18 
#include <functional>
#include <utility>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
     29 
     30 namespace xla {
     31 
     32 class HloComputation;
     33 class HloInstruction;
     34 
     35 // DfsHloVisitor with default action based on the HloInstruction being visited.
     36 // Users should not use this class directly, but use the type aliases
     37 // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead.
     38 template <typename HloInstructionPtr>
     39 class DfsHloVisitorWithDefaultBase
     40     : public DfsHloVisitorBase<HloInstructionPtr> {
     41  public:
     42   DfsHloVisitorWithDefaultBase() {}
     43   ~DfsHloVisitorWithDefaultBase() override {}
     44 
     45   // Default action performed on HloInstruction.
     46   virtual Status DefaultAction(HloInstructionPtr hlo_instruction) = 0;
     47 
     48   Status HandleElementwiseUnary(HloInstructionPtr hlo) override {
     49     return DefaultAction(hlo);
     50   }
     51   Status HandleElementwiseBinary(HloInstructionPtr hlo) override {
     52     return DefaultAction(hlo);
     53   }
     54 
     55   Status HandleBatchNormTraining(HloInstructionPtr hlo) override {
     56     return DefaultAction(hlo);
     57   }
     58 
     59   Status HandleBatchNormInference(HloInstructionPtr hlo) override {
     60     return DefaultAction(hlo);
     61   }
     62 
     63   Status HandleBatchNormGrad(HloInstructionPtr hlo) override {
     64     return DefaultAction(hlo);
     65   }
     66 
     67   Status HandleClamp(HloInstructionPtr clamp) override {
     68     return DefaultAction(clamp);
     69   }
     70   Status HandleConcatenate(HloInstructionPtr concatenate) override {
     71     return DefaultAction(concatenate);
     72   }
     73   Status HandleConvert(HloInstructionPtr convert) override {
     74     return DefaultAction(convert);
     75   }
     76   Status HandleCopy(HloInstructionPtr copy) override {
     77     return DefaultAction(copy);
     78   }
     79   Status HandleSelect(HloInstructionPtr select) override {
     80     return DefaultAction(select);
     81   }
     82   Status HandleDot(HloInstructionPtr dot) override {
     83     return DefaultAction(dot);
     84   }
     85   Status HandleConvolution(HloInstructionPtr convolution) override {
     86     return DefaultAction(convolution);
     87   }
     88   Status HandleFft(HloInstructionPtr fft) override {
     89     return DefaultAction(fft);
     90   }
     91   Status HandleCrossReplicaSum(HloInstructionPtr crs) override {
     92     return DefaultAction(crs);
     93   }
     94   Status HandleCompare(HloInstructionPtr compare) override {
     95     return DefaultAction(compare);
     96   }
     97   Status HandleRng(HloInstructionPtr random) override {
     98     return DefaultAction(random);
     99   }
    100   Status HandleInfeed(HloInstructionPtr infeed) override {
    101     return DefaultAction(infeed);
    102   }
    103   Status HandleOutfeed(HloInstructionPtr outfeed) override {
    104     return DefaultAction(outfeed);
    105   }
    106   Status HandleHostCompute(HloInstructionPtr host_compute) override {
    107     return DefaultAction(host_compute);
    108   }
    109   Status HandleReverse(HloInstructionPtr reverse) override {
    110     return DefaultAction(reverse);
    111   }
    112   Status HandleSort(HloInstructionPtr sort) override {
    113     return DefaultAction(sort);
    114   }
    115   Status HandleConstant(HloInstructionPtr constant) override {
    116     return DefaultAction(constant);
    117   }
    118   Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override {
    119     return DefaultAction(get_tuple_element);
    120   }
    121   Status HandleParameter(HloInstructionPtr parameter) override {
    122     return DefaultAction(parameter);
    123   }
    124   Status HandleFusion(HloInstructionPtr fusion) override {
    125     return DefaultAction(fusion);
    126   }
    127   Status HandleCall(HloInstructionPtr call) override {
    128     return DefaultAction(call);
    129   }
    130   Status HandleCustomCall(HloInstructionPtr custom_call) override {
    131     return DefaultAction(custom_call);
    132   }
    133   Status HandleSlice(HloInstructionPtr slice) override {
    134     return DefaultAction(slice);
    135   }
    136   Status HandleDynamicSlice(HloInstructionPtr dynamic_slice) override {
    137     return DefaultAction(dynamic_slice);
    138   }
    139   Status HandleDynamicUpdateSlice(
    140       HloInstructionPtr dynamic_update_slice) override {
    141     return DefaultAction(dynamic_update_slice);
    142   }
    143   Status HandleTuple(HloInstructionPtr tuple) override {
    144     return DefaultAction(tuple);
    145   }
    146   Status HandleMap(HloInstructionPtr map) override {
    147     return DefaultAction(map);
    148   }
    149   Status HandleReduce(HloInstructionPtr reduce) override {
    150     return DefaultAction(reduce);
    151   }
    152   Status HandleReduceWindow(HloInstructionPtr reduce_window) override {
    153     return DefaultAction(reduce_window);
    154   }
    155   Status HandleSelectAndScatter(HloInstructionPtr select_and_scatter) override {
    156     return DefaultAction(select_and_scatter);
    157   }
    158   Status HandleBitcast(HloInstructionPtr bitcast) override {
    159     return DefaultAction(bitcast);
    160   }
    161   Status HandleBroadcast(HloInstructionPtr broadcast) override {
    162     return DefaultAction(broadcast);
    163   }
    164   Status HandlePad(HloInstructionPtr pad) override {
    165     return DefaultAction(pad);
    166   }
    167   Status HandleReshape(HloInstructionPtr reshape) override {
    168     return DefaultAction(reshape);
    169   }
    170   Status HandleTranspose(HloInstructionPtr transpose) override {
    171     return DefaultAction(transpose);
    172   }
    173   Status HandleWhile(HloInstructionPtr xla_while) override {
    174     return DefaultAction(xla_while);
    175   }
    176   Status HandleConditional(HloInstructionPtr conditional) override {
    177     return DefaultAction(conditional);
    178   }
    179   Status HandleRecv(HloInstructionPtr recv) override {
    180     return DefaultAction(recv);
    181   }
    182   Status HandleRecvDone(HloInstructionPtr recv_done) override {
    183     return DefaultAction(recv_done);
    184   }
    185   Status HandleSend(HloInstructionPtr send) override {
    186     return DefaultAction(send);
    187   }
    188   Status HandleSendDone(HloInstructionPtr send_done) override {
    189     return DefaultAction(send_done);
    190   }
    191   Status HandleGather(HloInstructionPtr gather) override {
    192     return DefaultAction(gather);
    193   }
    194 
    195   // Invoked to inform the visitor that the traversal has completed, and that
    196   // the root was "root".
    197   Status FinishVisit(HloInstructionPtr /*root*/) override {
    198     return Status::OK();
    199   }
    200 
    201  private:
    202   TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefaultBase);
    203 };
    204 
    205 // Users should use these type aliases which are only two valid instantiations.
    206 using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<HloInstruction*>;
    207 using ConstDfsHloVisitorWithDefault =
    208     DfsHloVisitorWithDefaultBase<const HloInstruction*>;
    209 
    210 // (Const)FunctionVisitor lets you transform an
    211 // std::function<Status((const) HloInstruction*)> into a (Const)DfsHloVisitor.
    212 //
    213 // This is useful if you have code that needs to handle visitors in the form of
    214 // both std::function and DfsHloVisitor.  You can wrap the function in a
    215 // FunctionVisitor and then treat it like any other DfsHloVisitor.
    216 template <typename HloInstructionPtr>
    217 class FunctionVisitorBase
    218     : public DfsHloVisitorWithDefaultBase<HloInstructionPtr> {
    219  public:
    220   explicit FunctionVisitorBase(
    221       std::function<Status(HloInstructionPtr)> visitor_func)
    222       : visitor_func_(std::move(visitor_func)) {}
    223 
    224   Status DefaultAction(HloInstructionPtr hlo_instruction) override {
    225     return visitor_func_(hlo_instruction);
    226   }
    227 
    228  private:
    229   TF_DISALLOW_COPY_AND_ASSIGN(FunctionVisitorBase);
    230 
    231   std::function<Status(HloInstructionPtr)> visitor_func_;
    232 };
    233 
    234 using FunctionVisitor = FunctionVisitorBase<HloInstruction*>;
    235 using ConstFunctionVisitor = FunctionVisitorBase<const HloInstruction*>;
    236 
    237 }  // namespace xla
    238 
    239 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
    240