      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // HLO instructions are in DAG form and represent the computations that the user
     17 // has built up via the XLA service interface. They are ultimately lowered
     18 // in a platform-aware way by traversing the HLO DAG and emitting a lowered
     19 // form; e.g. see DfsHloVisitor.
     20 
     21 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
     22 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
     23 
     24 #include <functional>
     25 #include <iosfwd>
     26 #include <list>
     27 #include <memory>
     28 #include <set>
     29 #include <string>
     30 #include <tuple>
     31 #include <unordered_map>
     32 #include <unordered_set>
     33 #include <vector>
     34 
     35 #include "tensorflow/compiler/xla/iterator_util.h"
     36 #include "tensorflow/compiler/xla/literal_util.h"
     37 #include "tensorflow/compiler/xla/map_util.h"
     38 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
     39 #include "tensorflow/compiler/xla/service/hlo.pb.h"
     40 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     41 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
     42 #include "tensorflow/compiler/xla/service/name_uniquer.h"
     43 #include "tensorflow/compiler/xla/types.h"
     44 #include "tensorflow/compiler/xla/xla_data.pb.h"
     45 #include "tensorflow/core/lib/core/status.h"
     46 #include "tensorflow/core/lib/core/stringpiece.h"
     47 #include "tensorflow/core/lib/gtl/array_slice.h"
     48 #include "tensorflow/core/lib/gtl/flatmap.h"
     49 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     50 #include "tensorflow/core/lib/gtl/iterator_range.h"
     51 #include "tensorflow/core/platform/logging.h"
     52 #include "tensorflow/core/platform/macros.h"
     53 #include "tensorflow/core/platform/types.h"
     54 
     55 namespace xla {
     56 
     57 class HloComputation;
     58 class HloModule;
     59 
     60 // A bunch of switches that control how the hlo text should be printed.
     61 class HloPrintOptions {
     62  public:
     63   // Constructs the default print options: don't print large constants, don't
     64   // compact operands, no indentation.
     65   HloPrintOptions()
     66       : print_large_constants_(false),
     67         print_subcomputation_references_(true),
     68         print_metadata_(true),
     69         compact_operands_(false),
     70         print_operand_shape_(true),
     71         print_program_shape_(true),
     72         print_percent_(true),
     73         indent_amount_(0) {}
     74 
     75   static HloPrintOptions ShortParsable() {
     76     return HloPrintOptions()
     77         .set_print_large_constants(true)
     78         .set_print_subcomputation_references(true)
     79         .set_print_metadata(false)
     80         .set_print_operand_shape(false)
     81         .set_print_program_shape(false)
     82         .set_print_percent(false);
     83   }
     84 
     85   // If true, large constants will be printed out.
     86   HloPrintOptions& set_print_large_constants(bool value) {
     87     print_large_constants_ = value;
     88     return *this;
     89   }
     90 
     91   // If false, the names of subcomputations (e.g. a fusion node's fused
     92   // computation) won't be printed, which makes the resulting text not parsable.
     93   //
     94   // A CustomCall's call target is printed even if
     95   // print_subcomputation_references is false, because the call target isn't an
     96   // HloComputation.
     97   HloPrintOptions& set_print_subcomputation_references(bool value) {
     98     print_subcomputation_references_ = value;
     99     return *this;
    100   }
    101 
    102   // If true, metadata will be printed.
    103   HloPrintOptions& set_print_metadata(bool value) {
    104     print_metadata_ = value;
    105     return *this;
    106   }
    107 
    108   // If true, operands' shapes will be printed.
    109   HloPrintOptions& set_print_operand_shape(bool value) {
    110     print_operand_shape_ = value;
    111     return *this;
    112   }
    113 
    114   // If true, the program shapes of HLO computations will be printed.
    115   HloPrintOptions& set_print_program_shape(bool value) {
    116     print_program_shape_ = value;
    117     return *this;
    118   }
    119 
    120   // If true, names will be printed with prefix '%'.
    121   HloPrintOptions& set_print_percent(bool value) {
    122     print_percent_ = value;
    123     return *this;
    124   }
    125 
    126   // If true, only some of the operands will be printed, and their names will
    127   // be omitted (note that in this case the text will not be parsable).
    128   HloPrintOptions& set_compact_operands(bool value) {
    129     compact_operands_ = value;
    130     return *this;
    131   }
    132 
    133   // The amount by which the HLO text block is indented.
    134   HloPrintOptions& set_indent_amount(int value) {
    135     indent_amount_ = value;
    136     return *this;
    137   }
    138 
    139   bool print_large_constants() const { return print_large_constants_; }
    140   bool print_subcomputation_references() const {
    141     return print_subcomputation_references_;
    142   }
    143   bool print_metadata() const { return print_metadata_; }
    144   bool compact_operands() const { return compact_operands_; }
    145   bool print_operand_shape() const { return print_operand_shape_; }
    146   bool print_program_shape() const { return print_program_shape_; }
    147   bool print_percent() const { return print_percent_; }
    148   int indent_amount() const { return indent_amount_; }
    149 
    150  private:
    151   bool print_large_constants_;
    152   bool print_subcomputation_references_;
    153   bool print_metadata_;
    154   bool compact_operands_;
    155   bool print_operand_shape_;
    156   bool print_program_shape_;
    157   bool print_percent_;
    158   int indent_amount_;
    159 };
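
        // Example (illustrative sketch, not taken from the code above): the setters
        // return *this, so options can be chained and passed to
        // HloInstruction::ToString(const HloPrintOptions&). `instruction` is an
        // assumed HloInstruction*.
        //
        //   HloPrintOptions options = HloPrintOptions()
        //                                 .set_print_metadata(false)
        //                                 .set_print_operand_shape(false)
        //                                 .set_indent_amount(2);
        //   string text = instruction->ToString(options);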
    160 
    161 // HLO instructions are the IR used by the high-level compiler.
    162 class HloInstruction {
    163  public:
    164   enum class FusionKind {
    165     kLoop,          // Fused into a loop.
    166     kInput,         // Op's input is fused into the op itself.
    167     kOutput,        // Op's output is fused into the op itself.
    168                     // REQUIRES: At least one operand buffer must be able
    169                     // to alias the output buffer.
    170     kTransposeDot,  // Fused into a dot with transposed operands.
    171     kCustom,        // Custom category for backend-specific fusions that
    172                     // do not match any of the more specific ones.
    173   };
    174 
    175   ~HloInstruction();
    176 
    177   // Creates an instruction from the given proto. Arguments:
    178   //
    179   //   module: the module which will contain the instruction. The newly created
    180   //     instruction is *not* added to the module or any computation, however.
    181   //   proto: the proto to convert from.
    182   //   instruction_map: a map from instruction name to HloInstruction*. This map
    183   //     must contain all operands of the newly constructed instruction.
    184   //   computation_map: a map from computation name to HloComputation*. This map
    185   //     must contain all computations which the newly constructed instruction
    186   //     calls.
    187   //   add_fused_computation: A function to call to add a fused
    188   //     computation. Used only when the instruction is a fusion
    189   //     instruction.
    190   static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
    191       HloModule* module, const HloInstructionProto& proto,
    192       const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
    193       const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
    194       const std::function<void(std::unique_ptr<HloComputation>)>&
    195           add_fused_computation);
    196 
    197   // Creates a parameter-retrieving instruction.
    198   static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
    199                                                          const Shape& shape,
    200                                                          const string& name);
    201 
    202   // Creates a literal constant instruction.
    203   static std::unique_ptr<HloInstruction> CreateConstant(
    204       std::unique_ptr<Literal> literal);
    205 
    206   // Creates a get tuple element instruction.
    207   static std::unique_ptr<HloInstruction> CreateGetTupleElement(
    208       const Shape& shape, HloInstruction* operand, int64 index);
    209 
    210   // Creates a trace instruction that logs the input operand in the computation.
    211   static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
    212                                                      HloInstruction* operand);
    213 
    214   // Creates a random number generation instruction that fills a shape with
    215   // random numbers from a given distribution.
    216   static std::unique_ptr<HloInstruction> CreateRng(
    217       const Shape& shape, RandomDistribution distribution,
    218       tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
    219 
    220   // Creates a unary instruction (one operand).
    221   // Precondition: opcode must be a legitimate unary operation.
    222   static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
    223                                                      HloOpcode opcode,
    224                                                      HloInstruction* operand);
    225 
    226   // Creates a binary instruction (two operands).
    227   // Precondition: opcode must be a legitimate binary operation.
    228   static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
    229                                                       HloOpcode opcode,
    230                                                       HloInstruction* lhs,
    231                                                       HloInstruction* rhs);
    232 
    233   // Creates a ternary instruction (three operands).
    234   // Precondition: opcode must be a legitimate ternary operation.
    235   static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
    236                                                        HloOpcode opcode,
    237                                                        HloInstruction* lhs,
    238                                                        HloInstruction* rhs,
    239                                                        HloInstruction* ehs);
    240 
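
          // Example (illustrative sketch): building a small expression with the factory
          // methods above. `computation` is an assumed HloComputation* whose
          // AddInstruction method takes ownership of each new instruction, and `shape`
          // could be e.g. ShapeUtil::MakeShape(F32, {4}).
          //
          //   HloInstruction* x = computation->AddInstruction(
          //       HloInstruction::CreateParameter(0, shape, "x"));
          //   HloInstruction* y = computation->AddInstruction(
          //       HloInstruction::CreateParameter(1, shape, "y"));
          //   HloInstruction* sum = computation->AddInstruction(
          //       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y));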
    241   // Creates a variadic instruction (variable number of operands).
    242   // Precondition: opcode must be a legitimate variadic operation.
    243   static std::unique_ptr<HloInstruction> CreateVariadic(
    244       const Shape& shape, HloOpcode opcode,
    245       tensorflow::gtl::ArraySlice<HloInstruction*> operands);
    246 
    247   // Creates a map instruction, where `map_computation` is
    248   // applied element-wise to every element in operands (across the operands,
    249   // at a given index) with the same `static_operands`.
    250   static std::unique_ptr<HloInstruction> CreateMap(
    251       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    252       HloComputation* map_computation,
    253       tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});
    254 
    255   // Creates a convolution op, where rhs is the convolutional filter
    256   // and window describes how the filter is applied to lhs.
    257   static std::unique_ptr<HloInstruction> CreateConvolve(
    258       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    259       const Window& window,
    260       const ConvolutionDimensionNumbers& dimension_numbers);
    261 
    262   // Creates an FFT op, of the type indicated by fft_type.
    263   static std::unique_ptr<HloInstruction> CreateFft(
    264       const Shape& shape, HloInstruction* operand, FftType fft_type,
    265       tensorflow::gtl::ArraySlice<int64> fft_length);
    266 
    267   // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
    268   // dimensions specified in 'dimension_numbers'.
    269   static std::unique_ptr<HloInstruction> CreateDot(
    270       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    271       const DotDimensionNumbers& dimension_numbers);
    272 
    273   // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
    274   // of the LHS with dimension 0 of the RHS with no batch dimensions.  Both LHS
    275   // and the RHS must be of rank 2.
    276   static std::unique_ptr<HloInstruction> CreateCanonicalDot(
    277       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);
    278 
    279   // Creates a reduce-precision op, where operand is the data to reduce in
    280   // precision, and exponent_bits and mantissa_bits describe the precision to
    281   // reduce it to.
    282   static std::unique_ptr<HloInstruction> CreateReducePrecision(
    283       const Shape& shape, HloInstruction* operand, const int exponent_bits,
    284       const int mantissa_bits);
    285 
    286   // Creates a cross replica sum op.
    287   static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
    288       const Shape& shape,
    289       tensorflow::gtl::ArraySlice<HloInstruction*> operands);
    290 
    291   // Creates a conversion instruction, where operand is the data to convert and
    292   // shape is the target shape for the conversion.
    293   static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
    294                                                        HloInstruction* operand);
    295 
    296   // Creates a bitcast conversion instruction, where operand is the data to
    297   // convert and shape is the target shape for the conversion.
    298   static std::unique_ptr<HloInstruction> CreateBitcastConvert(
    299       const Shape& shape, HloInstruction* operand);
    300 
    301   // Creates an infeed instruction, which reads data of the given shape from the
    302   // Infeed interface of the device.
    303   static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape,
    304                                                       const string& config);
    305 
    306   // Creates an outfeed instruction, which outputs data.
    307   static std::unique_ptr<HloInstruction> CreateOutfeed(
    308       const Shape& shape, HloInstruction* operand,
    309       tensorflow::StringPiece outfeed_config);
    310 
    311   // Creates an asynchronous send instruction with the given channel id, which
    312   // initiates sending the operand data to a unique receive instruction in
    313   // another computation that has the same channel id.
    314   static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
    315                                                     int64 channel_id);
    316 
    317   // Blocks until data transfer for the Send instruction (operand) is complete.
    318   // The operand must be kSend.
    319   static std::unique_ptr<HloInstruction> CreateSendDone(
    320       HloInstruction* operand);
    321 
    322   // Creates an asynchronous receive instruction with the given channel id,
    323   // which allocates resources to receive data of the given shape from a unique
    324   // send instruction in another computation that has the same channel id.
    325   static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
    326                                                     int64 channel_id);
    327 
    328   // Blocks until data transfer for the Recv instruction (operand) is complete
    329   // and returns the receive buffer. The operand must be kRecv.
    330   static std::unique_ptr<HloInstruction> CreateRecvDone(
    331       HloInstruction* operand);
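
          // Example (illustrative sketch): a matching Send/Recv pair uses the same
          // channel id, and each transfer is completed by its *Done instruction.
          // `operand` is an assumed HloInstruction*, and the created instructions would
          // still need to be added to their computations:
          //
          //   auto send = HloInstruction::CreateSend(operand, /*channel_id=*/1);
          //   auto send_done = HloInstruction::CreateSendDone(send.get());
          //   auto recv = HloInstruction::CreateRecv(operand->shape(), /*channel_id=*/1);
          //   auto recv_done = HloInstruction::CreateRecvDone(recv.get());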
    332 
    333   // Creates a slice instruction, where the operand is sliced by the given
    334   // start/limit indices.
    335   static std::unique_ptr<HloInstruction> CreateSlice(
    336       const Shape& shape, HloInstruction* operand,
    337       tensorflow::gtl::ArraySlice<int64> start_indices,
    338       tensorflow::gtl::ArraySlice<int64> limit_indices,
    339       tensorflow::gtl::ArraySlice<int64> strides);
    340 
    341   // Creates a slice instruction, where the first operand is sliced by
    342   // start indices specified in the second operand, and by size specified in
    343   // 'slice_sizes'.
    344   static std::unique_ptr<HloInstruction> CreateDynamicSlice(
    345       const Shape& shape, HloInstruction* operand,
    346       HloInstruction* start_indices,
    347       tensorflow::gtl::ArraySlice<int64> slice_sizes);
    348 
    349   // Creates a dynamic update slice instruction, which updates a slice
    350   // of 'operand' with 'update' and 'start_indices'.
    351   static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
    352       const Shape& shape, HloInstruction* operand, HloInstruction* update,
    353       HloInstruction* start_indices);
    354 
    355   // Creates a concatenate instruction, where the operands are concatenated on
    356   // the provided dimension.
    357   static std::unique_ptr<HloInstruction> CreateConcatenate(
    358       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    359       int64 dimension);
    360 
    361   // Creates a reduce instruction, where `reduce_computation` is applied
    362   // successively to every element in operand. That is, if f is the
    363   // function to apply (which takes either 2 [accumulator, value] or 3
    364   // [accumulator, index, value] arguments) and init is the initial value of the
    365   // reduction (for example, 0 for addition), then this operation will
    366   // compute:
    367   //   f(...f(f(init, value0), value1)..., valueN)
    368   static std::unique_ptr<HloInstruction> CreateReduce(
    369       const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    370       tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
    371       HloComputation* reduce_computation);
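
          // Example (illustrative sketch): summing an f32[10,20] `operand` over
          // dimension 1. `zero` is an assumed f32[] constant instruction and
          // `add_computation` an assumed (F32, F32) -> F32 addition computation:
          //
          //   auto row_sums = HloInstruction::CreateReduce(
          //       ShapeUtil::MakeShape(F32, {10}), operand, zero,
          //       /*dimensions_to_reduce=*/{1}, add_computation);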
    372 
    373   // Creates a reduce-window instruction, where `reduce_computation` is
    374   // applied window-wise at each valid window
    375   // position in the operand.
    376   static std::unique_ptr<HloInstruction> CreateReduceWindow(
    377       const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    378       const Window& window, HloComputation* reduce_computation);
    379 
    380   // Creates a batch-norm-training instruction.
    381   static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
    382       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    383       HloInstruction* offset, float epsilon, int64 feature_index);
    384 
    385   // Creates a batch-norm-inference instruction.
    386   static std::unique_ptr<HloInstruction> CreateBatchNormInference(
    387       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    388       HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
    389       float epsilon, int64 feature_index);
    390 
    391   // Creates a batch-norm-grad instruction.
    392   static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
    393       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    394       HloInstruction* mean, HloInstruction* variance,
    395       HloInstruction* grad_output, float epsilon, int64 feature_index);
    396 
    397   // Creates a select-and-scatter instruction, which scatters the `source` array
    398   // to the selected indices of each window.
    399   static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
    400       const Shape& shape, HloInstruction* operand, HloComputation* select,
    401       const Window& window, HloInstruction* source, HloInstruction* init_value,
    402       HloComputation* scatter);
    403 
    404   // Creates a broadcast instruction.
    405   static std::unique_ptr<HloInstruction> CreateBroadcast(
    406       const Shape& shape, HloInstruction* operand,
    407       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
    408 
    409   // Creates a sequence of instructions that performs an explicit broadcast of
    410   // the operand to the target shape.
    411   //
    412   // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
    413   // returned as a unique_ptr for API consistency with other factory methods in
    414   // this interface.
    415   //
    416   // TODO(b/72173833) Ideally HloComputations would always be present, and so
    417   // the adder being passed by the caller would not be necessary.
    418   static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
    419       const Shape& output_shape, HloInstruction* operand,
    420       const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
    421           adder);
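
          // Example (illustrative sketch): the adder typically inserts each interior HLO
          // into an enclosing computation, e.g. via an assumed HloComputation*
          // `computation`:
          //
          //   HloInstruction* root = computation->AddInstruction(
          //       HloInstruction::CreateBroadcastSequence(
          //           output_shape, operand,
          //           [&](std::unique_ptr<HloInstruction> hlo) {
          //             return computation->AddInstruction(std::move(hlo));
          //           }));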
    422 
    423   // Creates a pad instruction, where the operand is padded on the edges and
    424   // between the elements with the given padding value.
    425   static std::unique_ptr<HloInstruction> CreatePad(
    426       const Shape& shape, HloInstruction* operand,
    427       HloInstruction* padding_value, const PaddingConfig& padding_config);
    428 
    429   // Creates a reshape instruction, where the operand is flattened in row-major
    430   // order and then reshaped to the given result shape.
    431   static std::unique_ptr<HloInstruction> CreateReshape(const Shape& shape,
    432                                                        HloInstruction* operand);
    433 
    434   // Creates a transpose instruction which permutes the operand dimensions.
    435   static std::unique_ptr<HloInstruction> CreateTranspose(
    436       const Shape& shape, HloInstruction* operand,
    437       tensorflow::gtl::ArraySlice<int64> dimensions);
    438 
    439   // Creates a while instruction, given a condition computation, a body
    440   // computation, and the initial value for the input of the computations. For
    441   // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
    442   // corresponds to the C code below.
    443   // int32 i = 1; while (i < 1000) { i = i * 2; }  // result is the final i
    444   static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
    445                                                      HloComputation* condition,
    446                                                      HloComputation* body,
    447                                                      HloInstruction* init);
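
          // Example (illustrative sketch): the doubling loop described above, assuming
          // `cond` and `body` are HloComputations over S32 and `one` is an S32 constant
          // instruction:
          //
          //   auto loop = HloInstruction::CreateWhile(
          //       ShapeUtil::MakeShape(S32, {}), cond, body, one);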
    448 
    449   static std::unique_ptr<HloInstruction> CreateConditional(
    450       const Shape& shape, HloInstruction* pred,
    451       HloInstruction* true_computation_arg, HloComputation* true_computation,
    452       HloInstruction* false_computation_arg, HloComputation* false_computation);
    453 
    454   static std::unique_ptr<HloInstruction> CreateGather(
    455       const Shape& shape, HloInstruction* operand,
    456       HloInstruction* gather_indices,
    457       const GatherDimensionNumbers& gather_dim_numbers,
    458       tensorflow::gtl::ArraySlice<int64> window_bounds);
    459 
    460   // Creates a fusion instruction. A fusion instruction contains one or more
    461   // fused instructions forming an expression with a single root
    462   // "fused_root". Additional instructions can be added to the fusion
    463   // instruction with the method FuseInstruction.
    464   static std::unique_ptr<HloInstruction> CreateFusion(
    465       const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);
    466 
    467   static std::unique_ptr<HloInstruction> CreateFusion(
    468       const Shape& shape, FusionKind fusion_kind,
    469       tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    470       HloComputation* fusion_computation);
    471 
    472   // Creates a call instruction that applies the given computation on the given
    473   // operands. "shape" is the resultant shape.
    474   static std::unique_ptr<HloInstruction> CreateCall(
    475       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    476       HloComputation* computation);
    477 
    478   // Creates a custom call instruction that applies the given custom call target
    479   // to the given operands. "shape" is the resultant shape.
    480   static std::unique_ptr<HloInstruction> CreateCustomCall(
    481       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    482       tensorflow::StringPiece custom_call_target);
    483 
    484   // Creates a HostCompute instruction, which records host-side control and
    485   // data dependencies for use in instruction scheduling.
    486   static std::unique_ptr<HloInstruction> CreateHostCompute(
    487       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    488       tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);
    489 
    490   // Creates a tuple instruction with the given elements. This is a convenience
    491   // wrapper around CreateVariadic.
    492   static std::unique_ptr<HloInstruction> CreateTuple(
    493       tensorflow::gtl::ArraySlice<HloInstruction*> elements);
    494 
    495   // Creates a reverse instruction, which reverses the order of the elements
    496   // in the specified dimensions.
    497   static std::unique_ptr<HloInstruction> CreateReverse(
    498       const Shape& shape, HloInstruction* operand,
    499       tensorflow::gtl::ArraySlice<int64> dimensions);
    500 
    501   // Creates an instance of GatherDimensionNumbers.
    502   static GatherDimensionNumbers MakeGatherDimNumbers(
    503       tensorflow::gtl::ArraySlice<int64> output_window_dims,
    504       tensorflow::gtl::ArraySlice<int64> elided_window_dims,
    505       tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims);
    506 
    507   // Returns the opcode for this instruction.
    508   HloOpcode opcode() const { return opcode_; }
    509 
    510   // Returns true if this instruction has a side effect. An instruction has a
    511   // side effect if it uses certain opcodes or calls a computation with a side
    512   // effect.
    513   bool HasSideEffect() const;
    514 
    515   // Returns the result shape of this instruction.
    516   const Shape& shape() const;
    517 
    518   // Returns the (mutable) result shape of this instruction.
    519   Shape* mutable_shape() { return &shape_; }
    520 
    521   // Returns the ith operand to this instruction.
    522   const HloInstruction* operand(int64 i) const;
    523 
    524   // Returns the ith operand to this instruction.
    525   HloInstruction* mutable_operand(int64 i);
    526 
    527   // Returns the number of operands to this instruction.
    528   int64 operand_count() const { return operands_.size(); }
    529 
    530   // Returns the vector of operands of this instruction.
    531   using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>;
    532   const InstructionVector& operands() const { return operands_; }
    533 
    534   // Returns the index of 'target' in the operands sequence.
    535   // Precondition: target must be an operand (or a fatal error will occur).
    536   int64 operand_index(const HloInstruction* target) const;
    537 
    538   // Returns the number of users of this instruction.
    539   int64 user_count() const { return users_.size(); }
    540 
    541   // Returns the users of this instruction.
    542   const std::vector<HloInstruction*>& users() const { return users_; }
    543 
    544   // Returns true if this instruction is a user of 'instruction'.
    545   bool IsUserOf(const HloInstruction* instruction) const {
    546     return ContainsKey(instruction->user_set_, this);
    547   }
    548 
    549   // Adds a control dependency from this instruction to the given
    550   // instruction. This instruction becomes a control predecessor of
    551   // 'instruction', and 'instruction' becomes a control successor of this
    552   // instruction. Returns an error status if either of the given instructions
    553   // does not belong to the same computation.
    554   //
    555   // This is used to enforce an additional ordering requirement that is not
    556   // captured by normal data dependencies, such as ordering among Send or Recv
    557   // operations to avoid deadlock.
    558   Status AddControlDependencyTo(HloInstruction* instruction);
    559 
    560   // Removes a previously added control dependency from this instruction to
    561   // 'instruction'.
    562   Status RemoveControlDependencyTo(HloInstruction* instruction);
    563 
    564   // Returns the set of control predecessors (successors) of this
    565   // instruction. Control predecessors (successors) must execute before (after)
    566   // the current instruction.
    567   const std::vector<HloInstruction*>& control_predecessors() const {
    568     return control_predecessors_;
    569   }
    570   const std::vector<HloInstruction*>& control_successors() const {
    571     return control_successors_;
    572   }
    573 
    574   // Returns true if "other" performs the same computation as this instruction.
    575   bool Identical(
    576       const HloInstruction& other,
    577       const std::function<bool(const HloInstruction*, const HloInstruction*)>&
    578           eq_operands = std::equal_to<const HloInstruction*>(),
    579       const std::function<bool(const HloComputation*, const HloComputation*)>&
    580           eq_computations = std::equal_to<const HloComputation*>(),
    581       bool layout_sensitive = true) const {
    582     // An instruction is always identical to itself.
    583     if (this == &other) {
    584       return true;
    585     }
    586 
    587     // Identical instructions must have the same opcode, shape, and identical
    588     // operands.
    589     if (opcode() != other.opcode()) {
    590       return false;
    591     }
    592     using EqShapeFuncType = bool (*)(const Shape&, const Shape&);
    593     EqShapeFuncType eq_shapes =
    594         layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible;
    595     if (!eq_shapes(shape(), other.shape())) {
    596       return false;
    597     }
    598     if (operands().size() != other.operands().size()) {
    599       return false;
    600     }
    601 
    602     // Use an explicit loop rather than ContainerEquals, because copying around
    603     // std::functions may be too expensive in some cases.
    604     for (size_t i = 0; i < operands().size(); ++i) {
    605       if (!eq_operands(operand(i), other.operand(i))) {
    606         return false;
    607       }
    608     }
    609 
    610     return IdenticalSlowPath(other, eq_computations, eq_shapes);
    611   }
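
          // Example (illustrative sketch): a layout-insensitive comparison that keeps
          // the default operand and computation equality predicates:
          //
          //   bool same = a->Identical(*b, std::equal_to<const HloInstruction*>(),
          //                            std::equal_to<const HloComputation*>(),
          //                            /*layout_sensitive=*/false);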
    612 
    613   // Returns whether the instruction has a constant operand.
    614   bool HasConstantOperand() const;
    615 
    616   // Returns whether this instruction does a rank-2 transposition.
    617   bool IsRank2Transpose() const;
    618 
    619   // Replaces the use of this instruction in "user" with "new_producer". Note
    620   // that there might be multiple uses of this instruction in "user"; all will
    621   // be replaced.
    622   Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
    623 
    624   // Replaces the specified operand with new_operand.
    625   Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);
    626 
    627   // Replaces all uses of this instruction with the new producer. If
    628   // new_producer is a user of this instruction then new_producer remains a user
    629   // of this instruction to avoid introducing cycles into the graph.
    630   //
    631   // If this instruction is the root of its computation, sets the computation's
    632   // root to new_producer.
    633   Status ReplaceAllUsesWith(HloInstruction* new_producer);
    634 
    635   // Detaches an instruction from its operands. That is, remove the instruction
    636   // from each operand's user set. This should only be called prior to
    637   // deallocating the instruction.
    638   void DetachFromOperands();
    639 
    640   // Performs a postorder DFS visit using this node as the root. If
    641   // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
    642   // complete. If ignore_control_predecessors is true, instructions only
    643   // reachable via control dependencies will not be visited, and the postorder
    644   // will not take control dependencies into account. It is as if the control
    645   // dependencies didn't exist in the graph at all.
    646   template <typename HloInstructionPtr>
    647   Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
    648                 bool call_finish_visit = true,
    649                 bool ignore_control_predecessors = false);
    650   Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
    651                 bool ignore_control_predecessors = false) const {
    652     return const_cast<HloInstruction*>(this)->Accept(
    653         visitor, call_finish_visit, ignore_control_predecessors);
    654   }
    655 
    656   // Same as Accept() above, but the order of operand and control predecessor
    657   // visitation is determined by the given operand order; if compare(A, B) ==
    658   // true, A is visited before B.
    659   using CompareFunction =
    660       std::function<bool(const HloInstruction*, const HloInstruction*)>;
    661   Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
    662                                 const CompareFunction& operand_order,
    663                                 bool call_finish_visit = true);
    664 
    665   // Performs a postorder DFS visit using this node as the root. Calls the given
    666   // visitor function at each instruction.
    667   Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
    668   Status Accept(
    669       const std::function<Status(const HloInstruction*)>& visitor_func) const;
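
          // Example (illustrative sketch): counting the instructions reachable from
          // `root` with the function-based overload. Returning Status::OK() lets the
          // traversal continue:
          //
          //   int64 count = 0;
          //   Status s = root->Accept([&count](const HloInstruction*) {
          //     ++count;
          //     return Status::OK();
          //   });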
    670 
    671   // Visits all instructions rooted at this instruction using the given visitor
    672   // in the given order. 'order' must contain at least the set of instructions
    673   // rooted at this node (i.e., those reachable by a DFS traversal from this
    674   // instruction). Instructions contained in 'order' which are not in the set of
    675   // instructions rooted at this node are ignored. 'order' must also be a valid
    676   // topological sort of these instructions (defs appear before uses) though
    677   // need not be a DFS post-order.
    678   Status AcceptOrdered(DfsHloVisitor* visitor,
    679                        const std::vector<const HloInstruction*>& order);
    680 
    681   // Visit this instruction and only this instruction with the given visitor.
    682   template <typename HloInstructionPtr>
    683   Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
    684 
    685   // Returns the literal associated with this instruction.
    686   //
    687   // Note: only constant and parameter opcodes have an associated literal.
    688   const Literal& literal() const;
    689 
    690   // Returns the parameter number associated with this instruction.
    691   //
    692   // Note: only parameter opcodes have an associated parameter number.
    693   int64 parameter_number() const {
    694     CHECK_EQ(HloOpcode::kParameter, opcode_);
    695     return parameter_number_;
    696   }
    697 
    698   // Returns the dimension sizes or numbers associated with this instruction.
    699   //
    700   // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape,
    701   // and reverse.
    702   const std::vector<int64>& dimensions() const;
    703   int64 dimensions(int64 index) const;
    704 
    705   // Accessor for the dimension in which a concatenate HLO should occur.
    706   // Precondition: opcode() == HloOpcode::kConcatenate
    707   int64 concatenate_dimension() const;
    708 
    709   // Returns the tuple index associated with this instruction.
    710   //
    711   // Precondition: opcode() == HloOpcode::kGetTupleElement
    712   int64 tuple_index() const;
    713 
    714   // Returns the first non-GetTupleElement ancestor instruction of this one. If
    715   // that ancestor is tuple-shaped, the returned ShapeIndex holds the (possibly
    716   // nested) tuple indices used on the path from the ancestor to this instruction.
    717   std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
    718       const;
    719 
    720   std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
    721     auto rv =
    722         const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
    723     return {const_cast<HloInstruction*>(rv.first), rv.second};
    724   }
    725 
    726   // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction.
    727   const HloInstruction* LatestNonGteAncestor() const;
    728 
    729   HloInstruction* LatestNonGteAncestor() {
    730     return const_cast<HloInstruction*>(
    731         const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
    732   }
    733 
    734   // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
    735   // The setter should only be called by HloModule or HloComputation methods.
    736   //
    737   // Precondition: The instruction has a valid to_apply_ field.
    738   HloComputation* to_apply() const;
    739   void set_to_apply(HloComputation* to_apply);
    740 
    741   // Returns the custom_call_target for CustomCall.
    742   // Precondition: opcode() == HloOpcode::kCustomCall
    743   const string& custom_call_target() const;
    744 
    745   // Returns the config for the Outfeed instruction.
    746   // Precondition: opcode() == HloOpcode::kOutfeed
    747   const string& outfeed_config() const;
    748 
    749   // Returns the shape for the Outfeed instruction.
    750   // Precondition: opcode() == HloOpcode::kOutfeed
    751   const Shape& outfeed_shape() const;
    752 
    753   // Gets/sets the while_condition or while_body HloComputation for While. The
    754   // setters should only be called by HloModule or HloComputation methods.
    755   //
    756   // Precondition: The instruction is a While instruction.
    757   HloComputation* while_condition() const;
    758   HloComputation* while_body() const;
    759   void set_while_condition(HloComputation* while_condition);
    760   void set_while_body(HloComputation* while_body);
    761 
    762   // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
    763   // setters should only be called by HloModule or HloComputation methods.
    764   //
    765   // Precondition: opcode() == HloOpcode::kSelectAndScatter.
    766   HloComputation* select() const;
    767   HloComputation* scatter() const;
    768   void set_select(HloComputation* select);
    769   void set_scatter(HloComputation* scatter);
    770 
    771   // Gets/sets the true and false HloComputation for Conditional. The setters
    772   // should only be called by HloModule or HloComputation methods.
    773   //
    774   // Precondition: The instruction is a Conditional instruction.
    775   HloComputation* true_computation() const;
    776   HloComputation* false_computation() const;
    777   void set_true_computation(HloComputation* true_computation);
    778   void set_false_computation(HloComputation* false_computation);
    779 
    780   // Returns a string for the signature of this instruction if considered as a
    781   // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
    782   string SignatureString() const;
    783 
    784   // Returns a debugging string that represents this instruction.
    785   //
    786   // (We express the default options using an overload rather than a default
    787   // param because gdb ignores default params, but does resolve overloads.)
    788   //
    789   // TODO(b/73348663): Make ToString() adaptive to the size of the string by
    790   // default, backing off on providing full information for very large strings,
    791   // or provide a different name for a ToString-like function that does that.
    792   string ToString() const { return ToString(HloPrintOptions()); }
    793   string ToString(const HloPrintOptions& options) const;
    794 
    795   // Components of the ToString() representation:
    796 
    797   // Returns a string representation of the operand list.
    798   string OperandsToString(const HloPrintOptions& options) const;
    799 
    800   // Returns string representation of op-specific attributes.
    801   std::vector<string> ExtraAttributesToString(
    802       const HloPrintOptions& options) const;
    803 
    804   // As ToString, but returns a shorter string.
    805   string ToShortString() const;
    806 
    807   // Returns a serialized representation of this instruction.
    808   HloInstructionProto ToProto() const;
    809 
    810   // Returns a category for the HLO. This could be something like "convolution"
    811   // or "elementwise".
    812   string ToCategory() const;
    813 
    814   // Returns a logging instruction, if the output of this instruction is logged.
    815   //
    816   // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
    817   HloInstruction* tracing() const;
    818   void set_tracing(HloInstruction* trace_instruction);
    819 
    820   // Returns the channel id associated with the instruction. The id is
    821   // shared between each Send/Recv pair and is globally unique to identify each
    822   // channel.
    823   //
    824   // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv
    825   int64 channel_id() const { return channel_id_; }
    826 
    827   // Returns the feature_index field associated with the instruction. The index
    828   // represents the index of the feature dimension.
    829   //
    830   // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
    831   // or kBatchNormGrad.
    832   int64 feature_index() const { return feature_index_; }
    833 
    834   // Returns the epsilon value associated with the instruction. This is a small
    835   // number added to the variance to avoid divide-by-zero errors.
    836   //
    837   // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
    838   // or kBatchNormGrad.
    839   float epsilon() const { return epsilon_; }
    840 
    841   // Returns the infeed configuration string. The infeed configuration includes
    842   // any metadata needed for the backend compiler (e.g., infeed buffer address)
    843   // and is target-dependent.
    844   string infeed_config() const { return infeed_config_; }
    845   void set_infeed_config(const string& config) { infeed_config_ = config; }
    846 
    847   // Returns a tag to be used in tracing.
    848   //
    849   // Precondition: opcode() == HloOpcode::kTrace
    850   string TracingTag() const;
    851 
    852   // Returns whether the instruction is a constant.
    853   bool IsConstant() const;
    854 
    855   // Returns true if this instruction is fused, i.e., contained within a fusion
    856   // instruction.
    857   bool IsFused() const;
    858 
    859   // Returns the computation for this fused instruction.
    860   //
    861   // Precondition: opcode() == HloOpcode::kFusion
    862   HloComputation* fused_instructions_computation() const;
    863 
    864   // Returns true if this instruction can be legally fused into a fusion
    865   // instruction.
    866   bool IsFusable() const;
    867 
    868   // Returns the root instruction of the fused expression contained within this
    869   // fusion instruction.
    870   //
    871   // Precondition: opcode() == HloOpcode::kFusion
    872   HloInstruction* fused_expression_root() const;
    873 
    874   // Returns the list of fused instructions inside this fusion instruction.  The
    875   // returned type is a range of HloInstruction*s.
    876   //
    877   // Precondition: opcode() == HloOpcode::kFusion
    878   const tensorflow::gtl::iterator_range<UnwrappingIterator<
    879       std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
    880   fused_instructions() const;
    881 
    882   const tensorflow::gtl::iterator_range<
    883       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
    884   fused_instructions();
    885 
    886   // Gets the number of instructions inside this fusion instruction.
    887   //
    888   // Precondition: opcode() == HloOpcode::kFusion
    889   int64 fused_instruction_count() const;
    890 
    891   // Returns the fused parameter instruction in this fusion instruction
    892   // corresponding to the given parameter number.
    893   //
    894   // Precondition: opcode() == HloOpcode::kFusion
    895   HloInstruction* fused_parameter(int64 parameter_number) const;
    896 
    897   // Returns the vector of fused parameters inside this fusion instruction.
    898   //
    899   // Precondition: opcode() == HloOpcode::kFusion
    900   const std::vector<HloInstruction*>& fused_parameters() const;
    901 
    902   // Returns true if this instruction is a fusion instruction that generates
    903   // multiple outputs.
    904   bool IsMultiOutputFusion() const {
    905     return opcode() == HloOpcode::kFusion &&
    906            fused_expression_root()->opcode() == HloOpcode::kTuple;
    907   }
    908 
    909   FusionKind fusion_kind() const {
    910     CHECK_EQ(HloOpcode::kFusion, opcode_);
    911     return fusion_kind_;
    912   }
    913 
    914   void set_fusion_kind(FusionKind kind) {
    915     CHECK_EQ(HloOpcode::kFusion, opcode_);
    916     fusion_kind_ = kind;
    917   }
    918 
    919   // Returns the sharding applied to this operator.
    920   // REQUIRES: has_sharding() is true.
    921   const HloSharding& sharding() const {
    922     CHECK(has_sharding());
    923     return *sharding_;
    924   }
    925   // Returns the sharding applied to this operator, or default_ if none exists.
    926   const HloSharding& sharding_or_default(const HloSharding& default_) const {
    927     return sharding_ ? *sharding_ : default_;
    928   }
    929   // Sets the sharding of this operator. Should only be called by HloModule or
    930   // HloComputation methods.
    931   void set_sharding(const HloSharding& sharding) {
    932     sharding_ = MakeUnique<HloSharding>(sharding);
    933   }
    934   // Remove any sharding from this operator.
    935   void clear_sharding() { sharding_ = nullptr; }
    936   // Return true if this operator has a sharding assigned.
    937   bool has_sharding() const { return sharding_ != nullptr; }
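
          // Example (illustrative sketch, assuming a replicated-sharding factory such as
          // HloSharding::Replicate() is available; the default is bound to a named
          // variable so the returned reference cannot dangle):
          //
          //   HloSharding replicated = HloSharding::Replicate();
          //   const HloSharding& s = instruction->sharding_or_default(replicated);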
    938 
    939   // Adds a new operand to the fusion instruction.
    940   HloInstruction* AddFusionOperand(HloInstruction* new_operand);
    941 
    942   // Merges the fused instructions from 'instruction_to_merge' into the
    943   // fused instruction set of 'this', updating operands as necessary.
    944   //
    945   // Precondition: opcode() == HloOpcode::kFusion
    946   // Precondition: 'instruction_to_merge' must be an operand of 'this'.
    947   void MergeFusionInstruction(HloInstruction* instruction_to_merge);
    948 
    949   // Merges the fused instructions from instruction_to_merge into the fused
    950   // instruction set of 'this', generating a multi-output fusion instruction.
    951   // All the users of instruction_to_merge will be redirected to 'this'
    952   // instruction. instruction_to_merge will be removed from its parent
    953   // computation.
    954   //
    955   // Precondition: opcode() == HloOpcode::kFusion
    956   void MergeFusionInstructionIntoMultiOutput(
    957       HloInstruction* instruction_to_merge);
    958 
    959   // Fuses the given instruction into this fusion instruction. instruction_to_fuse
    960   // is cloned and the clone is placed in the fusion
    961   // instruction. instruction_to_fuse is unchanged. The instruction is cloned rather
    962   // than moved to cleanly handle the case where the instruction has a use
    963   // outside the fusion instruction. Moving such an instruction into a fusion
    964   // instruction would violate the single-result invariant of HLO instructions
    965   // and significantly complicate code generation.
    966   //
    967   // Precondition: this->opcode() == HloOpcode::kFusion
    968   HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
    969     return FuseInstructionInternal(instruction_to_fuse);
    970   }
    971 
    972   // Fuses the given instruction into this fusion instruction and generates a
    973   // multi-output fusion instruction. A clone of instruction_to_fuse will
    974   // be part of the output of the fusion instruction. The users of
    975   // instruction_to_fuse will be redirected to this fusion instruction.
    976   // instruction_to_fuse will be removed from its parent computation.
    977   //
    978   // Precondition: this->opcode() == HloOpcode::kFusion
    979   HloInstruction* FuseInstructionIntoMultiOutput(
    980       HloInstruction* instruction_to_fuse) {
    981     return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true);
    982   }
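
          // Example (illustrative sketch): growing a loop fusion rooted at `add`, where
          // `mul` is an operand of `add` and `computation` is the assumed enclosing
          // HloComputation*:
          //
          //   HloInstruction* fusion = computation->AddInstruction(
          //       HloInstruction::CreateFusion(
          //           add->shape(), HloInstruction::FusionKind::kLoop, add));
          //   fusion->FuseInstruction(mul);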
    983 
    984   // Returns the start index in the given dimension for a slice node.
    985   //
    986   // Precondition: opcode() == HloOpcode::kSlice
    987   int64 slice_starts(int64 dimension) const {
    988     CHECK_EQ(HloOpcode::kSlice, opcode_);
    989     return slice_starts_[dimension];
    990   }
    991   const std::vector<int64>& slice_starts() const { return slice_starts_; }
    992 
    993   // Returns the (exclusive) limit index in the given dimension for a slice
    994   // node.
    995   //
    996   // Precondition: opcode() == HloOpcode::kSlice
    997   int64 slice_limits(int64 dimension) const {
    998     CHECK_EQ(HloOpcode::kSlice, opcode_);
    999     return slice_limits_[dimension];
   1000   }
   1001   const std::vector<int64>& slice_limits() const {
   1002     CHECK_EQ(HloOpcode::kSlice, opcode_);
   1003     return slice_limits_;
   1004   }
   1005 
   1006   // Returns the stride in the given dimension for a slice node.
   1007   //
   1008   // Precondition: opcode() == HloOpcode::kSlice
   1009   int64 slice_strides(int64 dimension) const {
   1010     CHECK_EQ(HloOpcode::kSlice, opcode_);
   1011     return slice_strides_[dimension];
   1012   }
   1013   const std::vector<int64>& slice_strides() const { return slice_strides_; }
   1014 
   1015   // Returns the flag that describes whether a slice must be lowered into an
   1016   // offset into the original operand.
   1017   bool IsInPlaceSlice() const { return is_in_place_slice_; }
   1018 
   1019   // Sets and returns the flag that describes whether a slice must be lowered
   1020   // into an offset into the original operand.
   1021   bool SetIsInPlaceSlice(bool value) {
   1022     is_in_place_slice_ = value;
   1023     return value;
   1024   }
   1025 
   1026   // Returns the size of the slice in the given dimension for a dynamic
   1027   // slice node.
   1028   //
   1029   // Precondition: opcode() == HloOpcode::kDynamicSlice
   1030   int64 slice_sizes(int64 dimension) const {
   1031     CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
   1032     return dynamic_slice_sizes_[dimension];
   1033   }
   1034   const std::vector<int64>& dynamic_slice_sizes() const {
   1035     CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
   1036     return dynamic_slice_sizes_;
   1037   }
   1038 
   1039   // Returns the number of exponent bits for a reduce-precision node.
   1040   //
   1041   // Precondition: opcode() == HloOpcode::kReducePrecision
   1042   int32 exponent_bits() const {
   1043     CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
   1044     return exponent_bits_;
   1045   }
   1046 
   1047   // Returns the number of mantissa bits for a reduce-precision node.
   1048   //
   1049   // Precondition: opcode() == HloOpcode::kReducePrecision
   1050   int32 mantissa_bits() const {
   1051     CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
   1052     return mantissa_bits_;
   1053   }
   1054 
   1055   // Returns data on the window in a windowed operation such as
   1056   // convolution.
   1057   const Window& window() const {
   1058     CHECK(window_ != nullptr);
   1059     return *window_;
   1060   }
   1061 
   1062   // Sets the window data in a windowed operation such as convolution.
   1063   void set_window(const Window& window) {
   1064     window_ = MakeUnique<Window>(window);
   1065   }
   1066 
   1067   // Returns the padding configuration for a pad node.
   1068   //
   1069   // Precondition: opcode() == HloOpcode::kPad
   1070   const PaddingConfig& padding_config() const {
   1071     CHECK(padding_config_ != nullptr);
   1072     return *padding_config_;
   1073   }
   1074 
   1075   // Returns data on the dimension numbers used for a convolution operation,
   1076   // which may be a kConvolution instruction or a kCustomCall that implements a
   1077   // convolution.
   1078   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
   1079     CHECK(convolution_dimension_numbers_ != nullptr);
   1080     return *convolution_dimension_numbers_;
   1081   }
   1082 
   1083   // Sets the convolution dimension numbers on this instruction.  In general you
   1084   // shouldn't need to call this; instead, specify the convolution dimension
   1085   // numbers when you create the instruction.
   1086   void set_convolution_dimension_numbers(
   1087       const ConvolutionDimensionNumbers& dnums) {
   1088     convolution_dimension_numbers_ =
   1089         MakeUnique<ConvolutionDimensionNumbers>(dnums);
   1090   }
   1091 
   1092   FftType fft_type() const {
   1093     CHECK_EQ(HloOpcode::kFft, opcode_);
   1094     return fft_type_;
   1095   }
   1096 
   1097   const std::vector<int64>& fft_length() const {
   1098     CHECK_EQ(HloOpcode::kFft, opcode_);
   1099     return fft_length_;
   1100   }
   1101 
   1102   // Returns the dump string of the convolution dimension numbers.
   1103   string ConvolutionDimensionNumbersToString() const;
   1104 
   1105   // Returns data on the dimension numbers used for a dot operation.
   1106   const DotDimensionNumbers& dot_dimension_numbers() const {
   1107     CHECK(dot_dimension_numbers_ != nullptr);
   1108     return *dot_dimension_numbers_;
   1109   }
   1110 
   1111   // Returns the dump string of the dot dimension numbers.
   1112   string DotDimensionNumbersToString() const;
   1113 
   1114   const GatherDimensionNumbers& gather_dimension_numbers() const {
   1115     CHECK(gather_dimension_numbers_ != nullptr);
   1116     return *gather_dimension_numbers_;
   1117   }
   1118 
   1119   tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
   1120     CHECK_EQ(opcode(), HloOpcode::kGather);
   1121     return gather_window_bounds_;
   1122   }
   1123 
   1124   // Returns the dump string of the gather dimension numbers.
   1125   string GatherDimensionNumbersToString() const;
   1126 
   1127   // Returns the random distribution for this rng node.
   1128   //
   1129   // Precondition: opcode() == HloOpcode::kRng
   1130   RandomDistribution random_distribution() const;
   1131 
   1132   // Clones the HLO instruction. The clone will have the same opcode, shape, and
   1133   // operands. After creation the clone has no uses. "this" (the instruction
   1134   // cloned from) is not changed. Suffix is the string to append to the name of
   1135   // the instruction to form the name of the cloned instruction.
  // If the module pointer is not nullptr, it is the module to which the cloned
  // computations will be added (in order to support deep cloning).
   1139   std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone",
   1140                                         HloModule* module = nullptr) const;
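
  // Illustrative sketch (hypothetical usage; `instr` and `module` are assumed
  // to already exist):
  //
  //   std::unique_ptr<HloInstruction> copy = instr->Clone("copy", module);
  //   // `copy` has the same opcode, shape and operands as `instr`, a name
  //   // derived from instr's name plus the "copy" suffix, and no users yet.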
   1141 
   1142   // Clones the HLO instruction as above but with new shape and operands.
  // If the module pointer is not nullptr, it is the module to which the cloned
  // computations will be added (in order to support deep cloning).
   1146   std::unique_ptr<HloInstruction> CloneWithNewOperands(
   1147       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
   1148       HloModule* module = nullptr) const;
   1149 
   1150   // Returns the computations this instruction directly calls (if any).
   1151   const std::vector<HloComputation*>& called_computations() const {
   1152     return called_computations_;
   1153   }
   1154 
  // Replaces each called computation with the result of applying map_function
  // to it. This is needed when cloning HloComputations so that the
  // instructions point to the newly cloned nodes.
   1158   void ReplaceCalledComputations(
   1159       std::function<HloComputation*(HloComputation*)> map_function) {
   1160     for (int64 i = 0; i < called_computations_.size(); ++i) {
   1161       called_computations_[i] = map_function(called_computations_[i]);
   1162     }
   1163   }
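
  // Illustrative sketch (hypothetical usage): after cloning computations into
  // `clone_map` (a map from original to cloned HloComputation*), repoint this
  // instruction's call sites at the clones.
  //
  //   instr->ReplaceCalledComputations([&](HloComputation* computation) {
  //     return clone_map.at(computation);
  //   });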
   1164 
   1165   // Clears out the called computations.
   1166   //
   1167   // This is, in particular, necessary when inlining function bodies into their
   1168   // caller. If there were side-effecting operations in the called computations,
   1169   // the call itself is considered side-effecting and thus cannot be removed. By
  // clearing out the computations, we record that all side-effecting behavior
  // has already been absorbed into the caller, which makes the call HLO
  // removable.
   1173   void ClearCalledComputations() { called_computations_.clear(); }
   1174 
  // Returns true if this instruction performs an elementwise operation on the
  // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
   1177   // after performing necessary implicit broadcast
   1178   // (cs/IrArray::EmitArrayElementAddress), to compute the output at index
   1179   // {i_0,i_1,...,i_n}, the only element required from the operand (if any) is
   1180   // the element at {i_0,i_1,...,i_n}.
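  //
  // For example (illustrative): a kAdd instruction is elementwise on both of
  // its operands, since output element {i_0,...,i_n} depends only on the
  // operands' elements at that same index, whereas a kDot instruction is not
  // elementwise on either operand, since each output element reads an entire
  // row/column of the operands.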
   1181   //
   1182   // Note on performance: when this instruction is kFusion, this method, in the
   1183   // worst case, scans all fused instructions. We could speed this up by
   1184   // caching.
   1185   bool IsElementwiseOnOperand(int64 operand_idx) const;
   1186 
   1187   // Returns true if this instruction is elementwise on all its operands.
   1188   bool IsElementwise() const;
   1189 
   1190   // Returns true if this elementwise instruction implicitly broadcasts operand
   1191   // `operand_idx`.
   1192   //
   1193   // Precondition: this instruction should be an elementwise operation.
   1194   bool ImplicitlyBroadcastsOperand(int64 operand_idx) const;
   1195 
   1196   // Returns true if this instruction is binary and elementwise.
   1197   bool IsElementwiseBinary() const;
   1198 
   1199   // Returns whether this instruction may reuse elements of its `i`th operand.
   1200   bool ReusesOperandElements(int64 i) const {
   1201     return OperandElementUse(i) == UseKind::kReuse;
   1202   }
   1203 
  // Returns the indices at which the given operand appears in the operand list
  // of this instruction. Note that an instruction can use the same operand
   1206   // multiple times.
   1207   std::vector<int64> OperandIndices(const HloInstruction* operand) const;
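
  // Illustrative example: if a hypothetical instruction `mul` computes x * x,
  // then mul->OperandIndices(x) returns {0, 1}, since x appears as both the
  // first and the second operand.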
   1208 
   1209   // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
   1210   // this reshape merely inserts or deletes 1-sized dimensions, return the input
   1211   // indices of the deleted dimensions and the output indices of the inserted
   1212   // dimensions.
   1213   //
   1214   // Precondition: this op must be a reshape.
   1215   std::tuple<bool, std::vector<int64>, std::vector<int64>>
   1216   ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
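
  // Illustrative sketch (hypothetical usage; `reshape` is assumed to be a
  // kReshape instruction):
  //
  //   bool trivial;
  //   std::vector<int64> deleted_dims, inserted_dims;
  //   std::tie(trivial, deleted_dims, inserted_dims) =
  //       reshape->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
  //   // e.g. a reshape from f32[4,1,3] to f32[4,3,1] would be reported as
  //   // merely deleting input dimension 1 and inserting output dimension 2.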
   1217 
   1218   // Gets/sets the string identifier for this instruction.
   1219   const string& name() const { return name_; }
   1220   void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); }
   1221 
   1222   // Use the given NameUniquer to select a unique name for the instruction based
   1223   // on the instruction's existing name.
   1224   void UniquifyName(NameUniquer* name_uniquer);
   1225 
  // Set the unique id for this instruction to "id".
   1227   void SetUniqueId(int id) {
   1228     CHECK_EQ(unique_id_, -1);  // Should not be assigned already
   1229     CHECK_GE(id, 0);
   1230     unique_id_ = id;
   1231   }
   1232 
   1233   // Return the unique ID assigned to this node via SetUniqueId (or -1
   1234   // if no id has been assigned yet).
   1235   int unique_id() const { return unique_id_; }
   1236 
   1237   // Sets the debug metadata for this instruction.
   1238   void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
   1239   const OpMetadata& metadata() const { return metadata_; }
   1240 
   1241   // Set/get the computation containing this instruction. set_parent should only
  // be called by HloComputation methods which add instructions to or remove
  // instructions from computations.
   1244   void set_parent(HloComputation* computation) { parent_ = computation; }
   1245   const HloComputation* parent() const { return parent_; }
   1246   HloComputation* parent() { return parent_; }
   1247 
   1248   // Returns the module for this instruction.
   1249   HloModule* GetModule() const;
   1250 
   1251   // Returns whether we could assign input and output layouts to this
   1252   // instruction to make it a bitcast.
   1253   bool CouldBeBitcast() const;
   1254 
  // Get/Set the number of partitions per outer dimension (listed in order,
  // outer-most dimension first). Currently used by the parallel CPU backend to
  // partition HLOs into parallel tasks.
   1258   // TODO(b/62783254) Replace these methods with a more general way to
   1259   // annotate HLOs with backend-specific information.
   1260   const std::vector<int64>& outer_dimension_partitions() const {
   1261     return outer_dimension_partitions_;
   1262   }
   1263   void set_outer_dimension_partitions(
   1264       const std::vector<int64>& outer_dimension_partitions);
   1265 
  // Changes the layout of a constant HLO instruction to match new_layout. For
  // tuple-shaped constants, shape_index is the path to the internal array
  // subshape whose layout needs to be changed.
   1269   void RelayoutConstant(const Layout& new_layout,
   1270                         const ShapeIndex& shape_index = {});
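
  // Illustrative sketch (hypothetical usage; assumes LayoutUtil is available
  // and `constant` is a rank-2 kConstant instruction):
  //
  //   constant->RelayoutConstant(LayoutUtil::MakeLayout({0, 1}));
  //   // For a tuple-shaped constant, also pass the ShapeIndex of the array
  //   // subshape whose layout should change, e.g. {1} for the second element.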
   1271 
   1272  private:
   1273   enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
   1274 
   1275   // Helper class for computing OperandElementUse for kFusion.
   1276   class FusionReusesParamElements;
   1277 
   1278   // See comments on Identical().
   1279   // eq_shapes() is used to check shapes for equality, and would normally be
   1280   // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on
   1281   // whether we want a layout-sensitive check or not.
   1282   bool IdenticalSlowPath(
   1283       const HloInstruction& other,
   1284       const std::function<bool(const HloComputation*, const HloComputation*)>&
   1285           eq_computations,
   1286       const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const;
   1287 
   1288   // Creates an n-ary elementwise operation.
   1289   static std::unique_ptr<HloInstruction> CreateNary(
   1290       const Shape& shape, HloOpcode opcode,
   1291       tensorflow::gtl::ArraySlice<HloInstruction*> operands);
   1292 
   1293   // Appends operand to the list of operands and adds this instruction as a user
   1294   // of the operand.
   1295   void AppendOperand(HloInstruction* operand);
   1296 
   1297   // Adds a user for this instruction.
   1298   void AddUser(HloInstruction* user);
   1299 
   1300   // Removes a user for this instruction.
   1301   void RemoveUser(HloInstruction* user);
   1302 
  // Internal constructor for a given opcode/shape; other fields must be filled
  // in by factory methods.
   1305   HloInstruction(HloOpcode opcode, const Shape& shape);
   1306 
   1307   // Fuses the given instruction into this fusion instruction. When add_output
   1308   // is false (which is the default), instruction_to_fuse is cloned and the
   1309   // clone is placed in the fusion instruction. instruction_to_fuse is
   1310   // unchanged.
   1311   //
  // When add_output is true, a clone of instruction_to_fuse will be part of
  // the output of this fusion instruction. The users of instruction_to_fuse
  // will be redirected to this fusion instruction, and instruction_to_fuse
  // will be removed from its parent computation.
   1316   //
   1317   // Precondition: this->opcode() == HloOpcode::kFusion
   1318   HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse,
   1319                                           bool add_output = false);
   1320 
  // Clones the given instruction_to_fuse and inserts the clone into this fusion
  // instruction. If add_output is true, a clone of instruction_to_fuse will
  // be part of the output of this fusion instruction (part of the tuple of the
   1324   // fusion root).
   1325   //
   1326   // Precondition: opcode() == HloOpcode::kFusion
   1327   HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse,
   1328                                        bool add_output = false);
   1329 
   1330   // Clones a fusion instruction with a new shape and operands.
   1331   std::unique_ptr<HloInstruction> CloneFusionWithNewOperands(
   1332       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
   1333       HloModule* module = nullptr) const;
   1334 
   1335   // Returns true if this instruction can legally have the dimensions field
   1336   // set. Used for checking precondition of dimensions field accessors.
   1337   bool CanHaveDimensionsField() const;
   1338 
   1339   // Returns how this instruction uses elements of its `i`th operand.
   1340   UseKind OperandElementUse(int64 i) const;
   1341 
   1342   int unique_id_;  // Unique to this HloInstruction within a HloModule
   1343 
   1344   // Opcode for this instruction.
   1345   HloOpcode opcode_;
   1346 
   1347   // Instruction operands.
   1348   InstructionVector operands_;
   1349 
   1350   // The set of control predecessors of this instruction.
   1351   std::vector<HloInstruction*> control_predecessors_;
   1352 
   1353   // The users of this instruction. Users are HLOs where this instruction is an
   1354   // operand. The vector users_ and the set user_set_ contain identical
   1355   // members. The set enables fast membership testing and the vector enables
   1356   // fast, stable iteration.
   1357   std::vector<HloInstruction*> users_;
   1358   std::unordered_set<const HloInstruction*> user_set_;
   1359 
   1360   // The set of control successors of this instruction.
   1361   std::vector<HloInstruction*> control_successors_;
   1362 
   1363   // The computation in which this instruction is contained.
   1364   HloComputation* parent_ = nullptr;
   1365 
   1366   // Shape of outfeed request.
   1367   Shape outfeed_shape_;
   1368 
   1369   // Result shape of this instruction.
   1370   Shape shape_;
   1371 
   1372   // Literal, only present for kConstant.
   1373   std::unique_ptr<Literal> literal_;
   1374 
   1375   // Constant index, only present for kGetTupleElement.
   1376   int64 tuple_index_ = -1;
   1377 
   1378   // Dimensions present for some operations that require reshaping or
   1379   // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
   1380   std::vector<int64> dimensions_;
   1381 
   1382   // Describes the window in a windowed operation such as convolution.
   1383   std::unique_ptr<Window> window_;
   1384 
   1385   // Describes the dimension numbers used for a convolution.
   1386   std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
   1387 
   1388   // Describes the dimension numbers used for a dot.
   1389   std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
   1390 
   1391   std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
   1392   std::vector<int64> gather_window_bounds_;
   1393 
   1394   // Describes FFT type for an FFT instruction.
   1395   FftType fft_type_ = FftType::FFT;
   1396 
   1397   // Indicates the FFT length for an FFT instruction.
   1398   std::vector<int64> fft_length_;
   1399 
   1400   // Describes the [begin, end) index range for a slice.
   1401   std::vector<int64> slice_starts_;
   1402   std::vector<int64> slice_limits_;
   1403   std::vector<int64> slice_strides_;
   1404 
   1405   // Describes whether the slice can be lowered to an offset into the operand.
   1406   bool is_in_place_slice_ = false;
   1407 
   1408   // The bit sizes for a reduce-precision operation.
   1409   int32 exponent_bits_ = 0;
   1410   int32 mantissa_bits_ = 0;
   1411 
   1412   // Describes the [start, start + size) range size for a dynamic slice
   1413   // ('start' is specified dynamically in the second operand of the operation).
   1414   std::vector<int64> dynamic_slice_sizes_;
   1415 
   1416   // The padding configuration that describes the edge padding and interior
   1417   // padding of this pad instruction. Only set for pad instructions.
   1418   std::unique_ptr<PaddingConfig> padding_config_;
   1419 
   1420   // The type of the fusion. Used by kFusion only.
   1421   FusionKind fusion_kind_;
   1422 
   1423   // The sharding, if one exists.
   1424   std::unique_ptr<HloSharding> sharding_;
   1425 
   1426   // For parameter instructions this field holds the parameter number.
   1427   int64 parameter_number_ = 0;
   1428 
   1429   // Name of a global symbol to call, only present for kCustomCall.
   1430   string custom_call_target_;
   1431 
   1432   // Name to use for host send/recv channels, only present for kHostCompute.
   1433   string channel_name_;
   1434 
   1435   // Estimate of the duration of a host computation in nanoseconds.
   1436   int64 cost_estimate_ns_;
   1437 
   1438   // Computations called by this instruction.
   1439   std::vector<HloComputation*> called_computations_;
   1440 
   1441   // Indices of computations in called_computations_ for instructions which call
   1442   // multiple computations.
   1443   enum {
   1444     // kWhile computations.
   1445     kBodyComputationIndex = 0,
   1446     kConditionComputationIndex = 1,
   1447 
   1448     // kSelectAndScatter computations.
   1449     kSelectComputationIndex = 0,
   1450     kScatterComputationIndex = 1,
   1451 
   1452     // kConditional computations.
   1453     kTrueComputationIndex = 0,
   1454     kFalseComputationIndex = 1,
   1455   };
   1456 
   1457   // Outfeed configuration information, only present for kOutfeed.
   1458   string outfeed_config_;
   1459 
   1460   // A trace instruction that consumes this instruction.
   1461   //
  // Invariant: if trace_instruction_ != nullptr, trace_instruction_ has this as
   1463   // an operand.
   1464   HloInstruction* trace_instruction_ = nullptr;
   1465 
   1466   // The distribution requested for random number generation.
   1467   // Only present for kRng.
   1468   RandomDistribution distribution_;
   1469 
  // A small float number added to the variance to avoid division by zero.
   1471   // Only present for kBatchNormTraining.
   1472   float epsilon_ = 0.0f;
   1473 
   1474   // An integer value representing the index of the feature dimension.
   1475   // Only present for kBatchNormTraining.
   1476   int64 feature_index_ = -1;
   1477 
   1478   // Represents a unique identifier for each Send/Recv instruction pair.
   1479   // Only present for kSend or kRecv.
   1480   int64 channel_id_ = -1;
   1481 
   1482   // The string representation of the infeed configuration.
   1483   string infeed_config_;
   1484 
   1485   // String identifier for instruction.
   1486   string name_;
   1487 
   1488   // Metadata for debugging.
   1489   OpMetadata metadata_;
   1490 
   1491   // The number of partitions per outer dimension (listed in order from
   1492   // outer-most dimension first).
   1493   std::vector<int64> outer_dimension_partitions_;
   1494 
   1495   TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
   1496 };
   1497 
   1498 string ToString(HloInstruction::FusionKind kind);
   1499 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
   1500     const string& kind_name);
   1501 
   1502 // Custom (de)stringification functions for protos that live inside
   1503 // HloInstruction.
   1504 string PaddingConfigToString(const PaddingConfig& padding);
   1505 string OpMetadataToString(const OpMetadata& metadata);
   1506 string RandomDistributionToString(const RandomDistribution& distribution);
   1507 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
   1508 
   1509 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
   1510 
   1511 // Map classes that guarantee a deterministic iteration order when the key is
   1512 // an HloInstruction* or a const HloInstruction*.
   1513 // To make the iteration order over the map deterministic, the comparator
   1514 // should not be using the pointer values, but rather an intrinsic property of
   1515 // the hlo.
   1516 //
   1517 // Note that this cannot be used for HLO instructions across multiple modules
   1518 // since the id of HLO instructions are only unique within each HLO module.
   1519 struct HloPtrComparator {
   1520   bool operator()(const HloInstruction* const& lhs,
   1521                   const HloInstruction* const& rhs) const {
   1522     return lhs->unique_id() < rhs->unique_id();
   1523   }
   1524 };
   1525 
   1526 template <typename ValueT>
   1527 using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;
   1528 
   1529 template <typename ValueT>
   1530 using ConstHloInstructionMap =
   1531     std::map<const HloInstruction*, ValueT, HloPtrComparator>;
   1532 
   1533 using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
   1534 using ConstHloInstructionSet =
   1535     std::set<const HloInstruction*, HloPtrComparator>;
   1536 
   1537 }  // namespace xla
   1538 
   1539 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
   1540