      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // HLO instructions are in DAG form and represent the computations that the user
     17 // has built up via the XLA service interface. They are ultimately lowered
     18 // in a platform-aware way by traversing the HLO DAG and emitting a lowered
     19 // form; e.g. see DfsHloVisitor.
     20 
     21 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
     22 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
     23 
     24 #include <functional>
     25 #include <iosfwd>
     26 #include <list>
     27 #include <memory>
     28 #include <set>
     29 #include <string>
     30 #include <tuple>
     31 #include <vector>
     32 
     33 #include "absl/container/flat_hash_map.h"
     34 #include "absl/container/flat_hash_set.h"
     35 #include "absl/container/inlined_vector.h"
     36 #include "absl/memory/memory.h"
     37 #include "absl/strings/str_cat.h"
     38 #include "absl/strings/string_view.h"
     39 #include "absl/types/span.h"
     40 #include "tensorflow/compiler/xla/comparison_util.h"
     41 #include "tensorflow/compiler/xla/iterator_util.h"
     42 #include "tensorflow/compiler/xla/literal.h"
     43 #include "tensorflow/compiler/xla/map_util.h"
     44 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
     45 #include "tensorflow/compiler/xla/service/hlo.pb.h"
     46 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
     47 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
     48 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     49 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
     50 #include "tensorflow/compiler/xla/service/name_uniquer.h"
     51 #include "tensorflow/compiler/xla/shape_tree.h"
     52 #include "tensorflow/compiler/xla/types.h"
     53 #include "tensorflow/compiler/xla/xla_data.pb.h"
     54 #include "tensorflow/core/lib/core/status.h"
     55 #include "tensorflow/core/lib/gtl/iterator_range.h"
     56 #include "tensorflow/core/platform/logging.h"
     57 #include "tensorflow/core/platform/macros.h"
     58 #include "tensorflow/core/platform/protobuf.h"
     59 #include "tensorflow/core/platform/types.h"
     60 
     61 namespace xla {
     62 
     63 class HloComputation;
     64 class HloModule;
     65 
     66 // A set of options that control how HLO text is printed.
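        // Illustrative usage sketch (not from this file): the setters below return
        // *this, so callers typically start from a preset, chain adjustments, and
        // pass the result to HloInstruction::ToString(). `instruction` here is a
        // placeholder HloInstruction*.
        //
        //   HloPrintOptions options = HloPrintOptions::ShortParsable()
        //                                 .set_print_metadata(true)
        //                                 .set_indent_amount(2);
        //   string text = instruction->ToString(options);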
     67 class HloPrintOptions {
     68  public:
     69   enum class PrintSubcomputationMode {
     70     kOff,         // Do not print anything about subcomputations.
     71     kNameOnly,    // Only print the name of subcomputations.
     72     kFullBodies,  // Print the full bodies of subcomputations.
     73   };
     74 
     75   // Constructs the default print options: don't print large constants, don't
     76   // compact operands, no indentation.
     77   HloPrintOptions()
     78       : print_large_constants_(false),
     79         print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly),
     80         print_metadata_(true),
     81         print_backend_config_(true),
     82         compact_operands_(false),
     83         print_operand_shape_(true),
     84         print_operand_names_(true),
     85         print_program_shape_(true),
     86         print_percent_(true),
     87         print_control_dependencies_(true),
     88         canonicalize_instruction_names_(false),
     89         indent_amount_(0),
     90         is_in_nested_computation_(false) {}
     91 
     92   static HloPrintOptions ShortParsable() {
     93     return HloPrintOptions()
     94         .set_print_large_constants(true)
     95         .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly)
     96         .set_print_metadata(false)
     97         .set_print_backend_config(false)
     98         .set_print_operand_shape(false)
     99         .set_print_program_shape(false)
    100         .set_print_percent(false)
    101         .set_print_control_dependencies(false);
    102   }
    103 
    104   // Options to produce the canonical string representing an isomorphic
    105   // computation graph.
    106   static HloPrintOptions Canonical() {
    107     return HloPrintOptions()
    108         .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
    109         .set_print_metadata(false)
    110         .set_print_backend_config(false)
    111         .set_compact_operands(true)
    112         .set_print_operand_names(false)
    113         .set_print_operand_shape(true)
    114         .set_print_program_shape(false)
    115         .set_print_percent(false)
    116         .set_print_control_dependencies(false)
    117         .set_canonicalize_instruction_names(true);
    118   }
    119 
    120   // If true, large constants will be printed out.
    121   HloPrintOptions& set_print_large_constants(bool value) {
    122     print_large_constants_ = value;
    123     return *this;
    124   }
    125 
    126   HloPrintOptions& set_print_subcomputation_mode(
    127       PrintSubcomputationMode value) {
    128     print_subcomputation_mode_ = value;
    129     return *this;
    130   }
    131 
    132   // If true, metadata will be printed.
    133   HloPrintOptions& set_print_metadata(bool value) {
    134     print_metadata_ = value;
    135     return *this;
    136   }
    137 
    138   // If true, backend_config will be printed.
    139   HloPrintOptions& set_print_backend_config(bool value) {
    140     print_backend_config_ = value;
    141     return *this;
    142   }
    143 
    144   // If true, operands' shapes will be printed.
    145   HloPrintOptions& set_print_operand_shape(bool value) {
    146     print_operand_shape_ = value;
    147     return *this;
    148   }
    149 
    150   // If true, the operand names will be printed.
    151   HloPrintOptions& set_print_operand_names(bool value) {
    152     print_operand_names_ = value;
    153     return *this;
    154   }
    155 
    156   // If true, program shape of hlo computations will be printed.
    157   HloPrintOptions& set_print_program_shape(bool value) {
    158     print_program_shape_ = value;
    159     return *this;
    160   }
    161 
    162   // If true, names will be printed with prefix '%'.
    163   HloPrintOptions& set_print_percent(bool value) {
    164     print_percent_ = value;
    165     return *this;
    166   }
    167 
    168   // If true, control dependencies will be printed.
    169   HloPrintOptions& set_print_control_dependencies(bool value) {
    170     print_control_dependencies_ = value;
    171     return *this;
    172   }
    173 
    174   // If true, only a part of the operands will be printed out (note that in this
    175   // case the text will not be parsable).
    176   HloPrintOptions& set_compact_operands(bool value) {
    177     compact_operands_ = value;
    178     return *this;
    179   }
    180 
    181   // If true, canonicalizes instruction names. Instead of using "%foo.1" as
    182   // the name of an instruction, we use "%tmp_1", "%tmp_2", etc.
    183   HloPrintOptions& set_canonicalize_instruction_names(bool value) {
    184     canonicalize_instruction_names_ = value;
    185     return *this;
    186   }
    187 
    188   // The indent of the hlo text block.
    189   HloPrintOptions& set_indent_amount(int value) {
    190     indent_amount_ = value;
    191     return *this;
    192   }
    193 
    194   // If true, indicates the instruction being printed is inside a nested
    195   // computation.
    196   HloPrintOptions& set_is_in_nested_computation(bool value) {
    197     is_in_nested_computation_ = value;
    198     return *this;
    199   }
    200 
    201   bool print_large_constants() const { return print_large_constants_; }
    202   PrintSubcomputationMode print_subcomputation_mode() const {
    203     return print_subcomputation_mode_;
    204   }
    205   bool print_metadata() const { return print_metadata_; }
    206   bool print_backend_config() const { return print_backend_config_; }
    207   bool compact_operands() const { return compact_operands_; }
    208   bool print_operand_shape() const { return print_operand_shape_; }
    209   bool print_operand_names() const { return print_operand_names_; }
    210   bool print_program_shape() const { return print_program_shape_; }
    211   bool print_percent() const { return print_percent_; }
    212   bool print_control_dependencies() const {
    213     return print_control_dependencies_;
    214   }
    215   bool canonicalize_instruction_names() const {
    216     return canonicalize_instruction_names_;
    217   }
    218   int indent_amount() const { return indent_amount_; }
    219   bool is_in_nested_computation() const { return is_in_nested_computation_; }
    220 
    221  private:
    222   bool print_large_constants_;
    223   PrintSubcomputationMode print_subcomputation_mode_;
    224   bool print_metadata_;
    225   bool print_backend_config_;
    226   bool compact_operands_;
    227   bool print_operand_shape_;
    228   bool print_operand_names_;
    229   bool print_program_shape_;
    230   bool print_percent_;
    231   bool print_control_dependencies_;
    232   bool canonicalize_instruction_names_;
    233   int indent_amount_;
    234   bool is_in_nested_computation_;
    235 };
    236 
    237 // For canonical string output, we need to have a canonical way to rename
    238 // each instruction and its operands. Each operand is renamed as "tmp_<xxx>",
    239 // where <xxx> is an index starting from 0.
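        // Illustrative sketch (not from this file): repeated lookups of the same name
        // return the same canonical name, while unseen names get the next index.
        //
        //   CanonicalNameMap map;
        //   map.LookupOrInsert("foo.1");  // returns "tmp_0"
        //   map.LookupOrInsert("bar.2");  // returns "tmp_1"
        //   map.LookupOrInsert("foo.1");  // returns "tmp_0" again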
    240 class CanonicalNameMap {
    241  public:
    242   CanonicalNameMap() : index(0) {}
    243 
    244   string LookupOrInsert(const string& old_name) {
    245     auto iter = canonical_name_map.find(old_name);
    246     if (iter != canonical_name_map.end()) {
    247       return iter->second;
    248     }
    249 
    250     string new_name = absl::StrCat("tmp_", index++);
    251     canonical_name_map[old_name] = new_name;
    252     return new_name;
    253   }
    254   void Clear() {
    255     canonical_name_map.clear();
    256     index = 0;
    257   }
    258 
    259  private:
    260   int64 index;
    261   absl::flat_hash_map<string, string> canonical_name_map;
    262 };
    263 
    264 // HLO instructions are the atomic unit of the high-level compiler's IR.
    265 //
    266 // HloInstructions live inside of an HloComputation, which is analogous to a
    267 // function in other programming languages.  Nodes have no total order within
    268 // their computation.  Instead, they have a partial ordering determined by their
    269 // data and control dependencies.
    270 //
    271 // HLO does not have basic blocks or explicit "branch" instructions.  Instead,
    272 // certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode
    273 // control flow.  For example, the kConditional HLO executes one of two possible
    274 // computations, depending on the runtime value of a predicate.
    275 //
    276 // HLO is pure (mostly).  It has no concept of mutable state.  Instead, data
    277 // values are produced by one HLO and flow into consumers across dependency
    278 // edges.
    279 class HloInstruction {
    280  public:
    281   // A fusion node computes the same value a call to its fusion computation
    282   // would compute.  However, the choice of fusion kind dictates codegen
    283   // strategy for the backend.
    284   //
    285   // To generate code for a kFusion HloInstruction, most backends do something
    286   // like the following:
    287   //
    288   // 1) Identify the "primary" HloInstruction of the fused computation.
    289   // 2) Emit code that does the work of the primary node, creating its inputs
    290   //    and transforming its outputs as specified by the fused computation.
    291   //
    292   // In step (2), the code emitted is usually similar to the code that would be
    293   // emitted for an *unfused* version of the primary node, except that
    294   //
    295   //  - when the primary node reads an element of one of its operands, instead
    296   //    of loading the value from memory, it *computes* the value based on the
    297   //    contents of the fused computation.
    298   //  - when the primary node outputs a value, instead of storing it to memory,
    299   //    it forwards the value to its users, which then perform additional
    300   //    computations before the value is finally stored to memory at the root of
    301   //    the fusion node.
    302   //
    303   // An HloInstruction's FusionKind helps us find the kFusion instruction's
    304   // primary node, and can also affect how we generate code in step (2).
    305   //
    306   //  - kInput: The primary node is the root of the fused instruction.
    307   //
    308   //  - kOutput: The primary node is not the root of the fused instruction.
    309   //    This fusion kind requires that one operand buffer of the fusion
    310   //    instruction be able to alias the output buffer.  This constraint is
    311   //    usually enough to let backends find the primary node unambiguously.
    312   //
    313   //  - kLoop: The primary node is the root of the fused computation, but,
    314   //    unlike in input fusion, we prescribe a specific implementation for
    315   //    codegen.  Rather than generating code that looks like the code we'd emit
    316   //    for an unfused version of the primary/root node, we emit code that
    317   //    generates one element of the root at a time.
    318   //
    319   //  - kCustom: Custom category for backend-specific fusions that don't fit
    320   //    into the above patterns.
    321   //
    322   // Not all backends support all fusion kinds, and given a particular fused
    323   // computation, it's not in general safe to change its fusion kind.  Creation
    324   // of fusion nodes is always backend-specific.
    325   //
    326   // For elementwise ops (e.g. kAdd), most backends would emit a
    327   // one-element-at-a-time implementation for the unfused version, so loop
    328   // fusion and input fusion are probably equivalent if the root node is
    329   // elementwise.  They're not necessarily equivalent e.g. for kReduce, where an
    330   // implementation might emit something more sophisticated for an unfused or
    331   // input-fusion reduce, but will emit the naive code that reduces one element
    332   // at a time for loop fusion with a reduce as the root.
    333   //
    334   // Another way to think of loop fusion is that it's equivalent to input
    335   // fusion, but where the root node is an implicit identity node, whose
    336   // unfused implementation is "read one element, write one element".
    337   //
    338   // TODO(b/79869434): This categorization scheme is not great.  For one thing,
    339   // input and loop fusion are basically the same thing: There is no reason for
    340   // the HLO to encode backend-specific decisions about how e.g. a reduce that's
    341   // the root of a fusion should be lowered.  In addition, this scheme as
    342   // written doesn't work for multi-output fusion, where the primary node is
    343   // never actually the root (which is a kTuple instruction that gathers the
    344   // multiple outputs of the fusion).
    345   enum class FusionKind {
    346     kLoop,
    347     kInput,
    348     kOutput,
    349     kCustom,
    350   };
    351 
    352   virtual ~HloInstruction();
    353 
    354   // Creates an instruction from the given proto. Arguments:
    355   //
    356   //   proto: the proto to convert from.
    357   //   instruction_map: a map from instruction id to HloInstruction*. This map
    358   //     must contain all operands of the newly constructed instruction.
    359   //   computation_map: a map from computation id to HloComputation*. This map
    360   //     must contain all computations which the newly constructed instruction
    361   //     calls.
    362   static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
    363       const HloInstructionProto& proto,
    364       const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
    365       const absl::flat_hash_map<int64, HloComputation*>& computation_map);
    366 
    367   // Creates a parameter-retrieving instruction.
    368   static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
    369                                                          const Shape& shape,
    370                                                          const string& name);
    371 
    372   // Creates a literal constant instruction.
    373   static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);
    374 
    375   // Creates an Iota instruction.
    376   static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
    377                                                     int64 iota_dimension);
    378 
    379   // Creates a get tuple element instruction.
    380   static std::unique_ptr<HloInstruction> CreateGetTupleElement(
    381       const Shape& shape, HloInstruction* operand, int64 index);
    382 
    383   // Creates a trace instruction that logs the input operand in the computation.
    384   static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
    385                                                      HloInstruction* operand);
    386 
    387   // Creates a random number generation instruction that fills a shape with
    388   // random numbers from a given distribution.
    389   //
    390   // The parameters to the instruction are interpreted as follows:
    391   //
    392   //  - If `distribution` is RNG_UNIFORM, generates a number in range
    393   //    [param0, param1).
    394   //
    395   //  - If `distribution` is RNG_NORMAL, generates a normally-distributed value
    396   //    with mean `param0` and standard deviation `param1`.
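          // Illustrative sketch (not from this file): a uniform draw in [0, 1) for an
          // F32[16] result, where `zero` and `one` are placeholder scalar F32
          // instructions already added to the computation.
          //
          //   auto rng = HloInstruction::CreateRng(
          //       ShapeUtil::MakeShape(F32, {16}), RNG_UNIFORM, {zero, one});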
    397   static std::unique_ptr<HloInstruction> CreateRng(
    398       const Shape& shape, RandomDistribution distribution,
    399       absl::Span<HloInstruction* const> parameters);
    400 
    401   // Creates a unary instruction (one operand).
    402   // Precondition: opcode must be a legitimate unary operation.
    403   static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
    404                                                      HloOpcode opcode,
    405                                                      HloInstruction* operand);
    406 
    407   // Creates a binary instruction (two operands).
    408   // Precondition: opcode must be a legitimate binary operation.
    409   static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
    410                                                       HloOpcode opcode,
    411                                                       HloInstruction* lhs,
    412                                                       HloInstruction* rhs);
    413 
    414   // Creates a ternary instruction (three operands).
    415   // Precondition: opcode must be a legitimate ternary operation.
    416   static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
    417                                                        HloOpcode opcode,
    418                                                        HloInstruction* lhs,
    419                                                        HloInstruction* rhs,
    420                                                        HloInstruction* ehs);
    421 
    422   // Creates a variadic instruction (variable number of operands).
    423   // Precondition: opcode must be a legitimate variadic operation.
    424   static std::unique_ptr<HloInstruction> CreateVariadic(
    425       const Shape& shape, HloOpcode opcode,
    426       absl::Span<HloInstruction* const> operands);
    427 
    428   // Creates a map instruction, where the computation (given by the handle) is
    429   // applied element-wise to every element in operands (across the operands,
    430   // at a given index).
    431   static std::unique_ptr<HloInstruction> CreateMap(
    432       const Shape& shape, absl::Span<HloInstruction* const> operands,
    433       HloComputation* map_computation);
    434 
    435   // Creates a convolution op, where rhs is the convolutional filter
    436   // and window describes how the filter is applied to lhs.
    437   static std::unique_ptr<HloInstruction> CreateConvolve(
    438       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    439       int64 feature_group_count, int64 batch_group_count, const Window& window,
    440       const ConvolutionDimensionNumbers& dimension_numbers,
    441       const PrecisionConfig& precision_config);
    442 
    443   // Creates an FFT op, of the type indicated by fft_type.
    444   static std::unique_ptr<HloInstruction> CreateFft(
    445       const Shape& shape, HloInstruction* operand, FftType fft_type,
    446       absl::Span<const int64> fft_length);
    447 
    448   // Creates a compare op, performing the comparison specified in direction.
    449   static std::unique_ptr<HloInstruction> CreateCompare(
    450       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    451       ComparisonDirection direction);
    452 
    453   static std::unique_ptr<HloInstruction> CreateTriangularSolve(
    454       const Shape& shape, HloInstruction* a, HloInstruction* b,
    455       const TriangularSolveOptions& options);
    456 
    457   static std::unique_ptr<HloInstruction> CreateCholesky(
    458       const Shape& shape, HloInstruction* a, const CholeskyOptions& options);
    459 
    460   // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
    461   // dimensions specified in 'dimension_numbers'.
    462   static std::unique_ptr<HloInstruction> CreateDot(
    463       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    464       const DotDimensionNumbers& dimension_numbers,
    465       const PrecisionConfig& precision_config);
    466 
    467   // Creates a reduce-precision op, where operand is the data to reduce in
    468   // precision, and exponent_bits and mantissa_bits describe the precision to
    469   // reduce it to.
    470   static std::unique_ptr<HloInstruction> CreateReducePrecision(
    471       const Shape& shape, HloInstruction* operand, const int exponent_bits,
    472       const int mantissa_bits);
    473 
    474   // Creates a cross replica reduction op.
    475   //
    476   // `reduction_computation`: the reduction function.
    477   //
    478   // `replica_groups`: each ReplicaGroup contains a list of replica ids. If
    479   // empty, all replicas belong to one group, in the order 0 - (n-1).
    480   // Allreduce will be applied within subgroups.
    481   // For example, with 4 replicas, replica_groups={{0,2},{1,3}} means replicas
    482   // 0 and 2 are in subgroup 0 and replicas 1 and 3 are in subgroup 1.
    483   //
    484   // `all_reduce_id`: for Allreduce nodes from different modules, if they have
    485   // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will
    486   // not be applied across modules.
    487   static std::unique_ptr<HloInstruction> CreateAllReduce(
    488       const Shape& shape, absl::Span<HloInstruction* const> operands,
    489       HloComputation* reduce_computation,
    490       const std::vector<ReplicaGroup>& replica_groups,
    491       absl::string_view barrier, const absl::optional<int64>& all_reduce_id);
    492 
    493   // This op handles the communication of an Alltoall operation. On each core,
    494   // the operands are N ops of the same shape, where N is the number of cores
    495   // participating in the Alltoall. The N operands are then scattered to N
    496   // cores, e.g., the ith operand is sent to the ith core. Each core then
    497   // gathers the received data into a tuple.
    498   //
    499   // - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
    500   // empty, all replicas belong to one group, in the order 0 - (n-1). Alltoall
    501   // will be applied within subgroups in the specified order. For example,
    502   // replica_groups = {{1,2,3},{4,5,0}} means an Alltoall will be applied
    503   // within replicas 1, 2, 3, and in the gather phase the received blocks will
    504   // be concatenated in the order 1, 2, 3; another Alltoall will be applied
    505   // within replicas 4, 5, 0, with concatenation order 4, 5, 0.
    506   static std::unique_ptr<HloInstruction> CreateAllToAll(
    507       const Shape& shape, absl::Span<HloInstruction* const> operands,
    508       const std::vector<ReplicaGroup>& replica_groups);
    509 
    510   // Creates a communication instruction that permutes data across replicas.
    511   // Data is sent/received according to the (source_replica_id,
    512   // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a
    513   // target_replica_id in any pair, the output on that replica is a tensor of
    514   // zeros with the given `shape`.
    515   static std::unique_ptr<HloInstruction> CreateCollectivePermute(
    516       const Shape& shape, HloInstruction* operand,
    517       const std::vector<std::pair<int64, int64>>& source_target_pairs);
    518 
    519   // Creates an instruction that returns a U32 replica ID.
    520   static std::unique_ptr<HloInstruction> CreateReplicaId();
    521 
    522   // Creates a conversion instruction, where operand is the data to convert and
    523   // shape is the target shape for the conversion.
    524   static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
    525                                                        HloInstruction* operand);
    526 
    527   // Creates a bitcast conversion instruction, where operand is the data to
    528   // convert and shape is the target shape for the conversion.
    529   static std::unique_ptr<HloInstruction> CreateBitcastConvert(
    530       const Shape& shape, HloInstruction* operand);
    531 
    532   // Creates an infeed instruction, which reads data of the given shape from the
    533   // Infeed interface of the device. infeed_shape is the shape of the data
    534   // received from the infeed, *not* the shape of the infeed instruction, which
    535   // is a tuple containing the infeed_shape and the TOKEN.
    536   static std::unique_ptr<HloInstruction> CreateInfeed(
    537       const Shape& infeed_shape, HloInstruction* token_operand,
    538       const string& config);
    539 
    540   // Creates an outfeed instruction, which outputs data. outfeed_shape is the
    541   // shape of the data being outfed, *not* the shape of the outfeed instruction,
    542   // which is a TOKEN.
    543   static std::unique_ptr<HloInstruction> CreateOutfeed(
    544       const Shape& outfeed_shape, HloInstruction* operand,
    545       HloInstruction* token_operand, absl::string_view outfeed_config);
    546 
    547   // Creates an asynchronous send instruction with the given channel id, which
    548   // initiates sending the operand data to a unique receive instruction in
    549   // another computation that has the same channel id. If is_host_transfer is
    550   // true, then this Send operation transfers data to the host.
    551   static std::unique_ptr<HloInstruction> CreateSend(
    552       HloInstruction* operand, HloInstruction* token, int64 channel_id,
    553       bool is_host_transfer = false);
    554 
    555   // Blocks until data transfer for the Send instruction (operand) is complete.
    556   // The operand must be kSend.
    557   static std::unique_ptr<HloInstruction> CreateSendDone(
    558       HloInstruction* operand, bool is_host_transfer = false);
    559 
    560   // Creates an asynchronous receive instruction with the given channel id,
    561   // which allocates resources to receive data of the given shape from a unique
    562   // send instruction in another computation that has the same channel id.  If
    563   // is_host_transfer is true, then this Recv operation transfers data from the
    564   // host.
    565   static std::unique_ptr<HloInstruction> CreateRecv(
    566       const Shape& shape, HloInstruction* token, int64 channel_id,
    567       bool is_host_transfer = false);
    568 
    569   // Blocks until data transfer for the Recv instruction (operand) is complete
    570   // and returns the receive buffer. The operand must be kRecv.
    571   static std::unique_ptr<HloInstruction> CreateRecvDone(
    572       HloInstruction* operand, bool is_host_transfer = false);
    573 
    574   // Creates a slice instruction, where the operand is sliced by the given
    575   // start/limit indices.
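          // Illustrative sketch (not from this file): take every other element of the
          // first 8 entries of a placeholder F32[8] `operand`, yielding F32[4].
          //
          //   auto slice = HloInstruction::CreateSlice(
          //       ShapeUtil::MakeShape(F32, {4}), operand,
          //       /*start_indices=*/{0}, /*limit_indices=*/{8}, /*strides=*/{2});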
    576   static std::unique_ptr<HloInstruction> CreateSlice(
    577       const Shape& shape, HloInstruction* operand,
    578       absl::Span<const int64> start_indices,
    579       absl::Span<const int64> limit_indices, absl::Span<const int64> strides);
    580 
    581   // Creates a dynamic slice instruction, where the operand is sliced at start
    582   // indices given by the 'start_indices' operands, with the size of each slice
    583   // dimension specified in 'slice_sizes'.
    584   static std::unique_ptr<HloInstruction> CreateDynamicSlice(
    585       const Shape& shape, HloInstruction* operand,
    586       absl::Span<HloInstruction* const> start_indices,
    587       absl::Span<const int64> slice_sizes);
    588 
    589   // Creates a dynamic update slice instruction, which updates a slice
    590   // of 'operand' with 'update' and 'start_indices'.
    591   static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
    592       const Shape& shape, HloInstruction* operand, HloInstruction* update,
    593       absl::Span<HloInstruction* const> start_indices);
    594 
    595   // Creates a concatenate instruction, where the operands are concatenated on
    596   // the provided dimension.
    597   static std::unique_ptr<HloInstruction> CreateConcatenate(
    598       const Shape& shape, absl::Span<HloInstruction* const> operands,
    599       int64 dimension);
    600 
    601   // Creates a reduce instruction, where the computation (given by the handle)
    602   // is applied successively to every element in operand. For example, let f be
    603   // the function to apply, which takes 2 arguments, an accumulator and the
    604   // current value. Let init be an initial value (which is normally chosen to be
    605   // the identity element for f, e.g. 0 if f is addition).
    606   // Then the reduce HLO will compute:
    607   // f(...f(f(init, value0), value1)..., valueN)
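          // Illustrative sketch (not from this file): sum a placeholder F32[8,16]
          // `input` over dimension 1, where `zero` is a placeholder scalar F32
          // constant and `add_computation` is a two-parameter scalar add computation.
          //
          //   auto sum = HloInstruction::CreateReduce(
          //       ShapeUtil::MakeShape(F32, {8}), input, zero,
          //       /*dimensions_to_reduce=*/{1}, add_computation);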
    608   static std::unique_ptr<HloInstruction> CreateReduce(
    609       const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    610       absl::Span<const int64> dimensions_to_reduce,
    611       HloComputation* reduce_computation);
    612 
    613   // A more general, multiple-argument version of the above.
    614   // The function to apply, f, now takes 2 * N arguments:
    615   // [accumulator0, accumulator1, ..., accumulatorN-1, value0, value1, ...,
    616   // valueN-1], and returns an N-tuple. The performed computation is (for
    617   // commutative and associative f operators) equivalent to:
    618   //
    619   // f_1 = f(init0, ...  initN, input0.value0, ..., inputN.value0)
    620   // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1,
    621   // ..., inputN.value1)
    622   // ...
    623   static std::unique_ptr<HloInstruction> CreateReduce(
    624       const Shape& shape, absl::Span<HloInstruction* const> operands,
    625       absl::Span<HloInstruction* const> init_values,
    626       absl::Span<const int64> dimensions_to_reduce,
    627       HloComputation* reduce_computation);
    628 
    629   // Creates a reduce-window instruction, where the computation (given
    630   // by the handle) is applied window-wise at each valid window
    631   // position in the operand.
    632   static std::unique_ptr<HloInstruction> CreateReduceWindow(
    633       const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    634       const Window& window, HloComputation* reduce_computation);
    635 
    636   // Creates a batch-norm-training instruction.
    637   static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
    638       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    639       HloInstruction* offset, float epsilon, int64 feature_index);
    640 
    641   // Creates a batch-norm-inference instruction.
    642   static std::unique_ptr<HloInstruction> CreateBatchNormInference(
    643       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    644       HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
    645       float epsilon, int64 feature_index);
    646 
    647   // Creates a batch-norm-grad instruction.
    648   static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
    649       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    650       HloInstruction* mean, HloInstruction* variance,
    651       HloInstruction* grad_output, float epsilon, int64 feature_index);
    652 
    653   // Creates a select-and-scatter instruction that scatters the `source` array
    654   // to the selected indices of each window.
    655   static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
    656       const Shape& shape, HloInstruction* operand, HloComputation* select,
    657       const Window& window, HloInstruction* source, HloInstruction* init_value,
    658       HloComputation* scatter);
    659 
    660   // Creates a broadcast instruction.
    661   static std::unique_ptr<HloInstruction> CreateBroadcast(
    662       const Shape& shape, HloInstruction* operand,
    663       absl::Span<const int64> broadcast_dimensions);
    664 
    665   // Creates a sequence of instructions that performs an explicit broadcast of
    666   // the operand to the target shape.
    667   //
    668   // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
    669   // returned as a unique_ptr for API consistency with other factory methods in
    670   // this interface.
    671   //
    672   // TODO(b/72173833) Ideally HloComputations would always be present, and so
    673   // the adder being passed by the caller would not be necessary.
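          // Illustrative sketch (not from this file): a typical adder hands each
          // interior HLO to a placeholder `computation` for ownership; `output_shape`
          // and `operand` are placeholders as well.
          //
          //   auto root = HloInstruction::CreateBroadcastSequence(
          //       output_shape, operand,
          //       [&](std::unique_ptr<HloInstruction> hlo) {
          //         return computation->AddInstruction(std::move(hlo));
          //       });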
    674   static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
    675       const Shape& output_shape, HloInstruction* operand,
    676       const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
    677           adder);
    678 
    679   // Creates a pad instruction, where the operand is padded on the edges and
    680   // between the elements with the given padding value.
    681   static std::unique_ptr<HloInstruction> CreatePad(
    682       const Shape& shape, HloInstruction* operand,
    683       HloInstruction* padding_value, const PaddingConfig& padding_config);
    684 
    685   // Creates a reshape instruction, where the operand is flattened in row-major
    686   // order and then reshaped to the given result shape.
    687   static std::unique_ptr<HloInstruction> CreateReshape(const Shape& shape,
    688                                                        HloInstruction* operand);
    689 
    690   // Creates a transpose instruction which permutes the operand dimensions.
    691   static std::unique_ptr<HloInstruction> CreateTranspose(
    692       const Shape& shape, HloInstruction* operand,
    693       absl::Span<const int64> dimensions);
    694 
    695   // Creates an n-ary sort op with a 'compare' computation which is used for
    696   // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters,
    697   // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at
    698   // specific index positions which should be compared, and should return a
    699   // PRED. 'is_stable' specifies whether stable sorting is required.
    700   static std::unique_ptr<HloInstruction> CreateSort(
    701       const Shape& shape, int64 dimension,
    702       absl::Span<HloInstruction* const> operands, HloComputation* compare,
    703       bool is_stable);
    704 
    705   // Creates a while instruction, given a condition computation, a body
    706   // computation, and the initial value for the input of the computations. For
    707   // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
    708   // roughly corresponds to the C-like pseudocode below.
    709   // int32 i = 1; while (i < 1000) { i = i * 2; }  // the result is the final i
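          // Illustrative sketch (not from this file): `condition` maps S32 -> PRED,
          // `body` maps S32 -> S32, and `init` is an S32 instruction; all three are
          // placeholders.
          //
          //   auto loop = HloInstruction::CreateWhile(
          //       ShapeUtil::MakeShape(S32, {}), condition, body, init);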
    710   static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
    711                                                      HloComputation* condition,
    712                                                      HloComputation* body,
    713                                                      HloInstruction* init);
    714 
    715   static std::unique_ptr<HloInstruction> CreateConditional(
    716       const Shape& shape, HloInstruction* pred,
    717       HloInstruction* true_computation_arg, HloComputation* true_computation,
    718       HloInstruction* false_computation_arg, HloComputation* false_computation);
    719 
    720   static std::unique_ptr<HloInstruction> CreateConditional(
    721       const Shape& shape, HloInstruction* branch_index,
    722       absl::Span<HloComputation* const> branch_computations,
    723       absl::Span<HloInstruction* const> branch_computation_args);
    724 
    725   static std::unique_ptr<HloInstruction> CreateGather(
    726       const Shape& shape, HloInstruction* operand,
    727       HloInstruction* start_indices,
    728       const GatherDimensionNumbers& gather_dim_numbers,
    729       absl::Span<const int64> slice_sizes);
    730 
    731   static std::unique_ptr<HloInstruction> CreateScatter(
    732       const Shape& shape, HloInstruction* operand,
    733       HloInstruction* scatter_indices, HloInstruction* updates,
    734       HloComputation* update_computation,
    735       const ScatterDimensionNumbers& scatter_dim_numbers);
    736 
    737   // Creates a kDomain instruction which delimits an HLO domain with
    738   // the provided user-side and operand-side metadata.
    739   static std::unique_ptr<HloInstruction> CreateDomain(
    740       const Shape& shape, HloInstruction* operand,
    741       std::unique_ptr<DomainMetadata> operand_side_metadata,
    742       std::unique_ptr<DomainMetadata> user_side_metadata);
    743 
    744   // Creates a fusion instruction. A fusion instruction contains one or more
    745   // fused instructions forming an expression with a single root
    746   // "fused_root". Additional instructions can be added to the fusion
    747   // instruction with the method FuseInstruction.
    748   static std::unique_ptr<HloInstruction> CreateFusion(
    749       const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);
    750 
    751   static std::unique_ptr<HloInstruction> CreateFusion(
    752       const Shape& shape, FusionKind fusion_kind,
    753       absl::Span<HloInstruction* const> operands,
    754       HloComputation* fusion_computation);
    755 
    756   // Creates a call instruction that applies the given computation on the given
    757   // operands. "shape" is the resultant shape.
    758   static std::unique_ptr<HloInstruction> CreateCall(
    759       const Shape& shape, absl::Span<HloInstruction* const> operands,
    760       HloComputation* computation);
    761 
    762   // Creates a custom call instruction that applies the given custom call target
    763   // to the given operands. "opaque" can be an arbitrary string with a
    764   // backend-specific interpretation. "shape" is the resultant shape.
    765   static std::unique_ptr<HloInstruction> CreateCustomCall(
    766       const Shape& shape, absl::Span<HloInstruction* const> operands,
    767       absl::string_view custom_call_target, absl::string_view opaque = "");
    768 
    769   // Overload which constrains the layouts of the operand and result. 'shape'
    770   // and 'operand_shapes_with_layout' must have layouts.
    771   // 'operand_shapes_with_layout' must have a compatible element for each
    772   // operand.
    773   static std::unique_ptr<HloInstruction> CreateCustomCall(
    774       const Shape& shape, absl::Span<HloInstruction* const> operands,
    775       absl::string_view custom_call_target,
    776       absl::Span<const Shape> operand_shapes_with_layout,
    777       absl::string_view opaque = "");
    778 
    779   // Creates a tuple instruction with the given elements. This is a convenience
    780   // wrapper around CreateVariadic.
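          // Illustrative sketch (not from this file): pack two placeholder
          // instructions `a` and `b` into a tuple owned by a placeholder
          // `computation`, then read the first element back.
          //
          //   HloInstruction* tuple = computation->AddInstruction(
          //       HloInstruction::CreateTuple({a, b}));
          //   HloInstruction* elem0 = computation->AddInstruction(
          //       HloInstruction::CreateGetTupleElement(a->shape(), tuple, 0));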
    781   static std::unique_ptr<HloInstruction> CreateTuple(
    782       absl::Span<HloInstruction* const> elements);
    783 
    784   // Creates a reverse instruction, which reverses the order of the elements
    785   // in the specified dimensions.
    786   static std::unique_ptr<HloInstruction> CreateReverse(
    787       const Shape& shape, HloInstruction* operand,
    788       absl::Span<const int64> dimensions);
    789 
    790   // Creates an AfterAll instruction used for joining or creating new values of
    791   // token type which thread through side-effecting operations. Operands must
    792   // all be tokens, and there must be at least one operand.
    793   static std::unique_ptr<HloInstruction> CreateAfterAll(
    794       absl::Span<HloInstruction* const> operands);
    795 
    796   // Creates an AfterAll instruction which creates a token type out of thin air
    797   // (no operands). This is a separate method from CreateAfterAll to facilitate
    798   // the removal of operand-less AfterAll instructions.
    799   // TODO(b/110532604): Remove this capability of creating a token from nothing
    800   // when we plumb a primordial token from the entry computation.
    801   static std::unique_ptr<HloInstruction> CreateToken();
    802 
    803   static std::unique_ptr<HloInstruction> CreateGetDimensionSize(
    804       const Shape& shape, HloInstruction* operand, int64 dimension);
    805 
    806   static std::unique_ptr<HloInstruction> CreateAddDependency(
    807       HloInstruction* data_operand, HloInstruction* token_operand);
    808 
    809   // Returns the opcode for this instruction.
    810   HloOpcode opcode() const { return opcode_; }
    811 
    812   // Returns true if this instruction has a side effect, irrespective of whether
    813   // any called computations may contain an instruction with side effects.
    814   bool HasSideEffectNoRecurse() const;
    815 
    816   // Returns true if this instruction has a side effect. An instruction has a
    817   // side effect if it uses certain opcodes or calls a computation with a side
    818   // effect.
    819   bool HasSideEffect() const;
    820 
    821   // Returns the result shape of this instruction.
    822   const Shape& shape() const;
    823 
    824   // Returns the (mutable) result shape of this instruction.
    825   Shape* mutable_shape() { return &shape_; }
    826 
    827   // Returns the ith operand to this instruction.
    828   const HloInstruction* operand(int64 i) const;
    829 
    830   // Returns the ith operand to this instruction.
    831   HloInstruction* mutable_operand(int64 i);
    832 
    833   // Returns the number of operands to this instruction.
    834   int64 operand_count() const { return operands_.size(); }
    835 
    836   // Returns the vector of operands of this instruction.
    837   using InstructionVector = absl::InlinedVector<HloInstruction*, 2>;
    838   const InstructionVector& operands() const { return operands_; }
    839 
    840   // Returns the vector of unique operands, in the same order they are found
    841   // within the operand vector.
    842   InstructionVector unique_operands() const;
    843 
    844   // Returns the index of 'target' in the operands sequence.
    845   // Precondition: target must be an operand (or a fatal error will occur).
    846   int64 operand_index(const HloInstruction* target) const;
    847 
    848   // Returns the number of users of this instruction.
    849   int64 user_count() const { return users_.size(); }
    850 
    851   // Returns the users of this instruction.
    852   const std::vector<HloInstruction*>& users() const { return users_; }
    853 
    854   // Returns true if this instruction is a user of 'instruction'.
    855   bool IsUserOf(const HloInstruction* instruction) const {
    856     return ContainsKey(instruction->user_set_, this);
    857   }
    858 
    859   // Adds a control dependency from this instruction to the given
    860   // instruction. This instruction becomes a control predecessor of
    861   // 'instruction', and 'instruction' becomes a control successor of this
    862   // instruction. Returns an error status if either of the given instructions
    863   // does not belong to the same computation.
    864   //
    865   // This is used to enforce an additional ordering requirement that is not
    866   // captured by normal data dependencies, such as ordering among Send or Recv
    867   // operations to avoid deadlock.
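          // Illustrative sketch (not from this file): force a placeholder `send` to
          // execute before a placeholder `recv`, inside a function returning Status.
          //
          //   TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv));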
    868   Status AddControlDependencyTo(HloInstruction* instruction);
    869 
    870   // Removes a previously added control dependency from this instruction to
    871   // 'instruction'.
    872   Status RemoveControlDependencyTo(HloInstruction* instruction);
    873 
    874   // Drops all control predecessors and successors from this HLO instruction.
    875   Status DropAllControlDeps();
    876 
    877   // Copies the control predecessors and successors on this HLO instruction to
    878   // `inst`.  Does not do a deep copy so this makes sense only if `inst` and
    879   // this HLO are in the same module.
    880   //
    881   // Depending on the use cases we see in practice, in the future we may
    882   // consider folding the logic here into Clone, CloneWithNewOperands and
    883   // ReplaceAllUsesWith by treating control dependencies like data dependencies.
    884   Status CopyAllControlDepsFrom(const HloInstruction* inst);
    885 
    886   // Returns the set of control predecessors (successors) of this
    887   // instruction. Control predecessors (successors) must execute before (after)
    888   // the current instruction.
    889   const std::vector<HloInstruction*>& control_predecessors() const {
    890     return control_predecessors_;
    891   }
    892   const std::vector<HloInstruction*>& control_successors() const {
    893     return control_successors_;
    894   }
    895 
    896   // Returns true if "other" performs the same computation as this instruction.
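          // Illustrative sketch (not from this file): a layout-insensitive comparison
          // of two placeholder instructions `a` and `b`, using the default operand and
          // computation equality.
          //
          //   bool same = a->Identical(
          //       *b, std::equal_to<const HloInstruction*>(),
          //       std::equal_to<const HloComputation*>(),
          //       /*layout_sensitive=*/false);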
    897   bool Identical(
    898       const HloInstruction& other,
    899       const std::function<bool(const HloInstruction*, const HloInstruction*)>&
    900           eq_operands = std::equal_to<const HloInstruction*>(),
    901       const std::function<bool(const HloComputation*, const HloComputation*)>&
    902           eq_computations = std::equal_to<const HloComputation*>(),
    903       bool layout_sensitive = true) const {
    904     // An instruction is always identical to itself.
    905     if (this == &other) {
    906       return true;
    907     }
    908 
    909     // Identical instructions must have the same opcode, shape, and identical
    910     // operands.
    911     if (opcode() != other.opcode()) {
    912       return false;
    913     }
    914     if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
    915                            : ShapeUtil::Compatible(shape(), other.shape()))) {
    916       return false;
    917     }
    918     if (operands().size() != other.operands().size()) {
    919       return false;
    920     }
    921 
    922     // Two AllReduces are Identical if they have the same all_reduce_id.
    923     // Their operands don't have to be Identical.
    924     if (!IsCrossModuleAllReduce()) {
    925       // Use an explicit loop rather than ContainerEquals, because copying
    926       // around std::functions may be too expensive in some cases.
    927       for (size_t i = 0; i < operands().size(); ++i) {
    928         if (!eq_operands(operand(i), other.operand(i))) {
    929           return false;
    930         }
    931       }
    932     }
    933 
    934     if (backend_config_ != other.backend_config_) {
    935       return false;
    936     }
    937 
    938     return IdenticalSlowPath(other, eq_computations);
    939   }
    940 
    941   // Generates a hash value of an HLO instruction. The hash takes into account
    942   // the opcode, shape, operands, and typically the root instruction.
    943   // This function returns the same hash value for equivalent HLO instructions,
    944   // with respect to the HloInstruction::Identical() method.
    945   //
    946   // Uses hash_operand function to compute hash values of its operands.
    947   // At the very top level, hash_operand should be non-recursive to prevent
    948   // non-termination.
    949   uint64 Hash(
    950       const std::function<uint64(const HloInstruction*)>& hash_operand) const;
    951 
    952   // Calls the above method with non-recursive hash_operand function.
    953   uint64 Hash() const;
    954 
    955   // Returns whether the instruction has a constant operand.
    956   bool HasConstantOperand() const;
    957 
    958   // Replaces the use of this instruction in "user" with "new_producer". Note
    959   // that there might be multiple uses of this instruction in "user"; all will
    960   // be replaced.
    961   //
    962   // If user is a fusion instruction, this function will remove any duplicated
    963   // operands of it which could be created due to this replacement.
    964   Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
    965 
    966   // Same as ReplaceUseWith(), but new_producer can have a different shape.
    967   Status ReplaceUseWithDifferentShape(HloInstruction* user,
    968                                       HloInstruction* new_producer);
    969 
    970   // Replaces the specified operand with new_operand. The old and new operands
    971   // must have compatible shapes ignoring floating-point precision.
    972   //
    973   // This function does NOT remove duplicated operands even if this instruction
    974   // is a fusion, so that the existing operand numbers do not change.
    975   Status ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand);
    976 
    977   // Same as ReplaceOperandWith(), but new_operand can have a different shape.
    978   Status ReplaceOperandWithDifferentShape(int64 operand_num,
    979                                           HloInstruction* new_operand);
    980 
    981   // Replaces all uses of this instruction with the new producer. If
    982   // new_producer is a user of this instruction then new_producer remains a use
    983   // of this instruction to avoid introducing cycles into the graph.
    984   //
    985   // If this instruction is the root of its computation, sets the computation's
    986   // root to new_producer.
    987   //
    988   // The new producer must have a compatible shape ignoring floating-point
    989   // precision.
    990   //
    991   // If a user is a fusion instruction, this function will remove any duplicated
    992   // operands of it which could be created due to this replacement.
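          // Illustrative sketch (not from this file): rewrite multiply(x, 1) to x by
          // redirecting all users of a placeholder `multiply` to its first operand,
          // inside a function returning Status.
          //
          //   TF_RETURN_IF_ERROR(
          //       multiply->ReplaceAllUsesWith(multiply->mutable_operand(0)));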
    993   Status ReplaceAllUsesWith(HloInstruction* new_producer);
    994 
    995   // Same as ReplaceAllUsesWith, but new_producer can have a different shape.
    996   Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer);
    997 
    998   // Performs a postorder DFS visit using this node as the root. If
    999   // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
   1000   // complete. If ignore_control_predecessors is true, instructions only
   1001   // reachable via control dependencies will not be visited, and the postorder
   1002   // will not take control dependencies into account. It is as if the control
   1003   // dependencies didn't exist in the graph at all.
   1004   template <typename HloInstructionPtr>
   1005   Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
   1006                 bool call_finish_visit = true,
   1007                 bool ignore_control_predecessors = false);
   1008   Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
   1009                 bool ignore_control_predecessors = false) const {
   1010     return const_cast<HloInstruction*>(this)->Accept(
   1011         visitor, call_finish_visit, ignore_control_predecessors);
   1012   }
   1013 
   1014   // Same as Accept() above, but the order of operand and control predecessor
   1015   // visitation is determined by the given operand order; if compare(A, B) ==
   1016   // true, A is visited before B.
   1017   using CompareFunction =
   1018       std::function<bool(const HloInstruction*, const HloInstruction*)>;
   1019   Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
   1020                                 const CompareFunction& operand_order,
   1021                                 bool call_finish_visit = true);
   1022 
   1023   // Performs a postorder DFS visit using this node as the root. Calls the given
   1024   // visitor function at each instruction.
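          // Illustrative sketch (not from this file): count the instructions reachable
          // from a placeholder `root`, inside a function returning Status.
          //
          //   int64 count = 0;
          //   TF_RETURN_IF_ERROR(root->Accept([&](HloInstruction* hlo) {
          //     ++count;
          //     return Status::OK();
          //   }));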
   1025   Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
   1026   Status Accept(
   1027       const std::function<Status(const HloInstruction*)>& visitor_func) const;
   1028 
   1029   // Visit this instruction and only this instruction with the given visitor.
   1030   template <typename HloInstructionPtr>
   1031   Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
   1032 
   1033   // Returns the first non-GetTupleElement ancestor instruction of this
   1034   // instruction, along with the (possibly nested) tuple indices used on the
   1035   // path from that ancestor back to this instruction.
   1036   std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
   1037       const;
   1038 
   1039   std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
   1040     auto rv =
   1041         const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
   1042     return {const_cast<HloInstruction*>(rv.first), rv.second};
   1043   }
   1044 
   1045   // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction.
   1046   const HloInstruction* LatestNonGteAncestor() const;
   1047 
   1048   HloInstruction* LatestNonGteAncestor() {
   1049     return const_cast<HloInstruction*>(
   1050         const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
   1051   }
   1052 
   1053   // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
   1054   // The setter should only be called by HloModule or HloComputation methods.
   1055   //
   1056   // Precondition: The instruction has a valid to_apply_ field.
   1057   HloComputation* to_apply() const;
   1058   void set_to_apply(HloComputation* to_apply);
   1059 
   1060   // Gets/sets the while_condition or while_body HloComputation for While. The
   1061   // setters should only be called by HloModule or HloComputation methods.
   1062   //
   1063   // Precondition: The instruction is a While instruction.
   1064   HloComputation* while_condition() const;
   1065   HloComputation* while_body() const;
   1066   void set_while_condition(HloComputation* while_condition);
   1067   void set_while_body(HloComputation* while_body);
   1068 
   1069   HloInstruction* while_init() const;
   1070 
   1071   // Gets/sets the true and false HloComputation for Conditional.
   1072   //
   1073   // Precondition: The instruction is a predicated Conditional instruction.
   1074   HloComputation* true_computation() const;
   1075   HloComputation* false_computation() const;
   1076 
   1077   // Gets the branch HloComputations for Conditional.
   1078   //
   1079   // Precondition: The instruction is a Conditional instruction.
   1080   const std::vector<HloComputation*>& branch_computations() const;
   1081   int branch_count() const;
   1082   HloComputation* branch_computation(int b) const;
   1083   // Sets a branch HloComputation for Conditional.
   1084   // The setter should only be called by HloModule or HloComputation methods.
   1085   //
   1086   // Precondition: The instruction is a Conditional instruction.
   1087   void set_branch_computation(int b, HloComputation* computation);
   1088 
   1089   // Returns a string for the signature of this instruction if considered as a
   1090   // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
   1091   string SignatureString() const;
   1092 
   1093   // Returns a debugging string that represents this instruction.
   1094   //
   1095   // (We express the default options using an overload rather than a default
   1096   // param because gdb ignores default params, but does resolve overloads.)
   1097   //
   1098   // TODO(b/73348663): Make ToString() adaptive to the size of the string by
   1099   // default, backing off on providing full information for very large strings,
   1100   // or provide a different name for a ToString-like function that does that.
   1101   string ToString() const { return ToString(HloPrintOptions()); }
   1102   string ToString(const HloPrintOptions& options) const;
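  // Example (illustrative sketch): dumping an instruction for debugging with
  // the default print options. Assumes `instr` is an HloInstruction*.
  //
  //   LOG(INFO) << instr->ToString(HloPrintOptions());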
   1103 
   1104   // Components of the ToString() representation:
   1105 
   1106   // Returns a string representation of the operand list.
   1107   string OperandsToString(const HloPrintOptions& options) const;
   1108 
   1109   // Returns string representation of op-specific attributes.
   1110   std::vector<string> ExtraAttributesToString(
   1111       const HloPrintOptions& options) const;
   1112 
   1113   // As ToString, but returns a shorter string.
   1114   string ToShortString() const;
   1115 
   1116   // Returns a serialized representation of this instruction.
   1117   virtual HloInstructionProto ToProto() const;
   1118 
   1119   // Returns a category for the HLO. This could be something like "convolution"
   1120   // or "elementwise".
   1121   virtual string ToCategory() const;
   1122 
   1123   // Returns a logging instruction, if the output of this instruction is logged.
   1124   //
   1125   // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
   1126   HloInstruction* tracing() const;
   1127   void set_tracing(HloInstruction* trace_instruction);
   1128 
    1129   // Returns true if this instruction is fused, i.e., contained within a fusion
   1130   // instruction.
   1131   bool IsFused() const;
   1132 
   1133   // Returns true if this instruction can be legally fused into a fusion
   1134   // instruction.
   1135   bool IsFusible() const;
   1136 
   1137   // Returns the sharding applied to this operator.
   1138   // REQUIRES: has_sharding() is true.
   1139   const HloSharding& sharding() const {
   1140     CHECK(has_sharding());
   1141     return *sharding_;
   1142   }
   1143   std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }
   1144 
   1145   // Returns the sharding applied to this operator, or default_ if none exists.
   1146   const HloSharding& sharding_or_default(const HloSharding& default_) const {
   1147     return sharding_ ? *sharding_ : default_;
   1148   }
   1149   // Returns the sharding unique device, if any.
   1150   absl::optional<int64> sharding_unique_device() const {
   1151     if (sharding_ == nullptr) {
   1152       return absl::optional<int64>();
   1153     }
   1154     return sharding_->UniqueDevice();
   1155   }
   1156   // Sets the sharding of this operator. Should only be called by HloModule or
   1157   // HloComputation methods.
   1158   void set_sharding(const HloSharding& sharding) {
   1159     sharding_ = std::make_shared<const HloSharding>(sharding);
   1160   }
   1161   void set_sharding(std::shared_ptr<const HloSharding> sharding) {
   1162     sharding_ = std::move(sharding);
   1163   }
   1164   void set_single_sharding(const HloSharding& sharding);
   1165   // Sets a sharding that assigns the current instruction to device.
   1166   void set_device_sharding(int64 device) {
   1167     set_single_sharding(HloSharding::AssignDevice(device));
   1168   }
   1169   // Remove any sharding from this operator.
   1170   void clear_sharding() { sharding_ = nullptr; }
   1171   // Return true if this operator has a sharding assigned.
   1172   bool has_sharding() const { return sharding_ != nullptr; }
   1173   // Checks whether the instruction has compatible sharding with the other
   1174   // instruction.
   1175   bool has_compatible_sharding(const HloInstruction* other) const {
   1176     if (!has_sharding()) {
   1177       return !other->has_sharding();
   1178     }
   1179     return other->has_sharding() ? sharding() == other->sharding() : false;
   1180   }
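  // Example (illustrative sketch): pinning an instruction to a single device
  // and querying the assignment back. Assumes `instr` is an HloInstruction*.
  //
  //   instr->set_device_sharding(/*device=*/0);
  //   if (instr->has_sharding()) {
  //     absl::optional<int64> device = instr->sharding_unique_device();
  //     // `device` holds 0 here, since the sharding assigns a single device.
  //   }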
   1181 
   1182   // When creating a new instruction which either replaces, or shifts up (kCopy
    1183   // insertion case), another instruction, we need to make sure that certain
    1184   // properties of the original instruction are copied into the derived one. As
    1185   // of today, the metadata and sharding are propagated to the derived
   1186   // instruction.
   1187   void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
   1188 
   1189   // Clones the HLO instruction. The clone will have the same opcode, shape, and
   1190   // operands. After creation the clone has no uses. "this" (the instruction
   1191   // cloned from) is not changed. Suffix is the string to append to the name of
   1192   // the instruction to form the name of the cloned instruction.
   1193   // Ignores the control predecessors and successors of this HLO instruction.
   1194   std::unique_ptr<HloInstruction> Clone(
   1195       const string& suffix = "clone", HloCloneContext* context = nullptr) const;
   1196 
   1197   // Clones the HLO instruction as above but with new shape and operands.
   1198   std::unique_ptr<HloInstruction> CloneWithNewOperands(
   1199       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
   1200       HloCloneContext* context = nullptr) const;
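  // Example (illustrative sketch): cloning an instruction with replacement
  // operands during a rewrite. Assumes `instr` is a binary HloInstruction* and
  // `new_lhs`/`new_rhs` are HloInstruction* replacements for its operands.
  //
  //   std::unique_ptr<HloInstruction> clone =
  //       instr->CloneWithNewOperands(instr->shape(), {new_lhs, new_rhs});
  //   // The clone has the same opcode and shape, no users, and is not yet
  //   // owned by any computation.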
   1201 
   1202   // Returns the computations this instruction directly calls (if any).
   1203   const std::vector<HloComputation*>& called_computations() const {
   1204     return called_computations_;
   1205   }
   1206 
   1207   // Replaces all called computations based on a map function. This is needed
    1208   // when we clone hlo_computations and want the instructions to point to the
    1209   // newly cloned nodes.
   1210   void ReplaceCalledComputations(
   1211       std::function<HloComputation*(HloComputation*)> map_function) {
   1212     for (int64 i = 0; i < called_computations_.size(); ++i) {
   1213       called_computations_[i] = map_function(called_computations_[i]);
   1214     }
   1215   }
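  // Example (illustrative sketch): redirecting called computations after
  // cloning. Assumes `replacements` is an absl::flat_hash_map from old to new
  // HloComputation*.
  //
  //   instr->ReplaceCalledComputations([&](HloComputation* computation) {
  //     auto it = replacements.find(computation);
  //     return it == replacements.end() ? computation : it->second;
  //   });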
   1216 
   1217   // Clears out the called computations.
   1218   //
   1219   // This is, in particular, necessary when inlining function bodies into their
   1220   // caller. If there were side-effecting operations in the called computations,
   1221   // the call itself is considered side-effecting and thus cannot be removed. By
    1222   // clearing out the computations, we record the fact that all side-effecting
    1223   // properties have already been transferred to the caller, which makes the
    1224   // call HLO removable.
   1225   void ClearCalledComputations() { called_computations_.clear(); }
   1226 
   1227   // Returns true if this instruction performs an elementwise operation on
   1228   // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
   1229   // to compute the output at index {i_0,i_1,...,i_n}, the only element required
   1230   // from the operand (if any) is the element at {i_0,i_1,...,i_n}.
   1231   //
   1232   // Note on performance: when this instruction is kFusion, this method, in the
   1233   // worst case, scans all fused instructions. We could speed this up by
   1234   // caching.
   1235   bool IsElementwiseOnOperand(int64 operand_idx) const;
   1236 
   1237   // Returns true if this instruction is elementwise on all its operands.
   1238   bool IsElementwise() const;
   1239 
   1240   // Returns true if this is a cross module all-reduce instruction.
   1241   bool IsCrossModuleAllReduce() const;
   1242 
   1243   // Returns true if this is a cross-replica all-reduce instruction.
   1244   bool IsCrossReplicaAllReduce() const;
   1245 
   1246   // Returns true if this instruction is binary and elementwise.
   1247   bool IsElementwiseBinary() const;
   1248 
   1249   // Returns whether this instruction may reuse elements of its `i`th operand.
   1250   bool ReusesOperandElements(int64 i) const {
   1251     return OperandElementUse(i) == UseKind::kReuse;
   1252   }
   1253 
    1254   // Returns the indices at which the given operand appears in the operand list
    1255   // of this instruction. Note that an instruction can use the same operand
   1256   // multiple times.
   1257   std::vector<int64> OperandIndices(const HloInstruction* operand) const;
   1258 
   1259   // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
   1260   // this reshape merely inserts or deletes 1-sized dimensions, return the input
   1261   // indices of the deleted dimensions and the output indices of the inserted
   1262   // dimensions.
   1263   //
   1264   // Precondition: this op must be a reshape.
   1265   std::tuple<bool, std::vector<int64>, std::vector<int64>>
   1266   ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
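  // Example (illustrative sketch): checking whether a reshape only adds or
  // removes degenerate (size-1) dimensions. Assumes `reshape` is a kReshape
  // HloInstruction*.
  //
  //   bool trivial;
  //   std::vector<int64> deleted_dims, inserted_dims;
  //   std::tie(trivial, deleted_dims, inserted_dims) =
  //       reshape->ReshapeMerelyInsertsOrDeletes1SizedDimensions();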
   1267 
   1268   // Gets the string identifier for this instruction.
   1269   const string& name() const { return name_; }
   1270 
   1271   // Sets the string identifier for this instruction. Name will be sanitized to
   1272   // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
   1273   void SetAndSanitizeName(const string& name) {
   1274     name_ = NameUniquer::GetSanitizedName(name);
   1275   }
   1276 
   1277   // Use the given NameUniquer to select a unique name for the instruction based
   1278   // on the instruction's existing name.
   1279   void UniquifyName(NameUniquer* name_uniquer);
   1280 
   1281   // Clear the unique ID of the instruction so that it can be re-assigned, such
   1282   // as for the purpose of compacting the instruction unique IDs.
   1283   void ClearUniqueIdInternal() { unique_id_ = -1; }
   1284 
    1285   // Sets the unique id for this instruction to "id".
   1286   void SetUniqueId(int id) {
   1287     CHECK_EQ(unique_id_, -1);  // Should not be assigned already
   1288     CHECK_GE(id, 0);
   1289     unique_id_ = id;
   1290   }
   1291 
   1292   // Return the unique ID assigned to this node via SetUniqueId (or -1
   1293   // if no id has been assigned yet).
   1294   int unique_id() const { return unique_id_; }
   1295 
   1296   // Returns the backend-specific configuration for how a backend should compile
   1297   // this HLO. The meaning of the field is backend specific. Not for use before
   1298   // or during general HLO optimization, since HLO optimizations do not preserve
   1299   // this field and they cannot interpret it due to its meaning being backend
   1300   // specific.
   1301   //
   1302   // ConfigProto should be a protobuf Message type.
   1303   template <typename ConfigProto>
   1304   StatusOr<ConfigProto> backend_config() const {
   1305     ConfigProto proto;
   1306     TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto));
   1307     return std::move(proto);
   1308   }
   1309   Status set_backend_config(const tensorflow::protobuf::Message& proto);
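  // Example (illustrative sketch): round-tripping a backend-specific config.
  // `MyBackendConfig` is a hypothetical proto Message; substitute the config
  // proto your backend actually uses.
  //
  //   MyBackendConfig config;
  //   config.set_some_flag(true);  // hypothetical field
  //   TF_RETURN_IF_ERROR(instr->set_backend_config(config));
  //   TF_ASSIGN_OR_RETURN(MyBackendConfig parsed,
  //                       instr->backend_config<MyBackendConfig>());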
   1310 
   1311   // Getter/setter for raw JSON-encoded backend config.  Prefer the
   1312   // functions above that deal in proto Messages where possible.
   1313   const string& raw_backend_config_string() const { return backend_config_; }
   1314   void set_raw_backend_config_string(string config_str) {
   1315     backend_config_ = std::move(config_str);
   1316   }
   1317 
   1318   bool is_default_config() const { return is_default_config_; }
   1319   void set_default_config() { is_default_config_ = true; }
   1320 
   1321   // Returns a string representation of a proto in the format used by
   1322   // raw_backend_config_string.
   1323   //
   1324   // This is morally equivalent to:
   1325   //
   1326   //   HloInstruction instr;
   1327   //   TF_RETURN_IF_ERROR(instr.set_backend_config(proto));
   1328   //   return instr.raw_backend_config_string();
   1329   //
   1330   static StatusOr<string> BackendConfigToRawString(
   1331       const tensorflow::protobuf::Message& proto);
   1332 
    1333   // Returns the information used to tell the implementation about what sort
    1334   // of precision is requested. The meaning of the field is backend specific.
    1335   // At the moment, it is only supported for kConvolution and kDot.
    1336   // Transformations of one kDot or kConvolution into another will preserve
    1337   // this information. Transformations to other HLOs will not preserve this
    1338   // information, but it is presumed that the alternate lowering is strictly
    1339   // superior.
    1340   // Precondition: opcode must be kConvolution or kDot.
   1341   const PrecisionConfig& precision_config() const;
   1342   PrecisionConfig* mutable_precision_config();
   1343 
   1344   // Sets the debug metadata for this instruction.
   1345   void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
   1346   const OpMetadata& metadata() const { return metadata_; }
   1347 
   1348   // Set/get the computation containing this instruction. set_parent should only
    1349   // be called by HloComputation methods which add/remove instructions to/from
   1350   // computations.
   1351   void set_parent(HloComputation* computation) { parent_ = computation; }
   1352   const HloComputation* parent() const { return parent_; }
   1353   HloComputation* parent() { return parent_; }
   1354 
   1355   // Returns the module for this instruction.
   1356   HloModule* GetModule() const;
   1357 
   1358   // Returns whether we could assign input and output layouts to this
   1359   // instruction to make it a bitcast.
   1360   bool CouldBeBitcast() const;
   1361 
    1362   // Get/Set the number of partitions per outer dimension (in order, outermost
    1363   // dimension first). Currently used by the parallel CPU backend to partition
    1364   // HLOs into parallel tasks.
   1365   //
   1366   // TODO(b/62783254) Replace these methods with a more general way to
   1367   // annotate HLOs with backend-specific information.
   1368   const std::vector<int64>& outer_dimension_partitions() const {
   1369     return outer_dimension_partitions_;
   1370   }
   1371   void set_outer_dimension_partitions(
   1372       const std::vector<int64>& outer_dimension_partitions);
   1373 
   1374   // Old methods kept for smooth subclassing transition BEGIN.
   1375   // TODO(b/80131774): Remove this code.
   1376 
   1377   // Delegates to HloBatchNormInstruction::feature_index.
   1378   int64 feature_index() const;
   1379 
   1380   // Delegates to HloBatchNormInstruction::epsilon.
   1381   float epsilon() const;
   1382 
   1383   // Delegates to HloFftInstruction::fft_type.
   1384   FftType fft_type() const;
   1385 
   1386   // Delegates to HloFftInstruction::fft_length.
   1387   const std::vector<int64>& fft_length() const;
   1388 
   1389   // Delegates to HloSendRecvInstruction::channel_id.
   1390   int64 channel_id() const;
   1391 
   1392   // Returns the dimension sizes or numbers associated with this instruction.
   1393   virtual const std::vector<int64>& dimensions() const {
   1394     LOG(FATAL) << "Unimplemented method.";
   1395   }
   1396   virtual int64 dimensions(int64 index) const {
   1397     LOG(FATAL) << "Unimplemented method.";
   1398   }
   1399 
   1400   // Delegates to HloConcatenateInstruction::concatenate_dimension.
   1401   int64 concatenate_dimension() const;
   1402 
   1403   // Delegates to HloGetDimensionSizeInstruction::dimension.
   1404   int64 dimension() const;
   1405 
   1406   // Returns whether this instruction does a rank-2 transposition.
   1407   bool IsRank2Transpose() const;
   1408 
   1409   // Delegates to HloSliceInstruction::slice_start.
   1410   int64 slice_starts(int64 dimension) const;
   1411   const std::vector<int64>& slice_starts() const;
   1412 
   1413   // Delegates to HloSliceInstruction::slice_limits.
   1414   int64 slice_limits(int64 dimension) const;
   1415   const std::vector<int64>& slice_limits() const;
   1416 
   1417   // Delegates to HloSliceInstruction::slice_strides.
   1418   int64 slice_strides(int64 dimension) const;
   1419   const std::vector<int64>& slice_strides() const;
   1420 
   1421   // Returns the literal associated with this instruction.
   1422   const Literal& literal() const;
   1423 
   1424   // Returns whether the instruction is a constant.
   1425   bool IsConstant() const;
   1426 
    1427   // Delegates to HloConstantInstruction::RelayoutConstant.
   1428   void RelayoutConstant(const Layout& new_layout,
   1429                         const ShapeIndex& shape_index = {});
   1430 
   1431   // Delegates to HloTraceInstruction::TracingTag.
   1432   string TracingTag() const;
   1433 
   1434   // Delegates to HloFusionInstruction::AddFusionOperand.
   1435   HloInstruction* AddFusionOperand(HloInstruction* new_operand);
   1436 
   1437   // Delegates to HloFusionInstruction::MergeFusionInstruction.
   1438   void MergeFusionInstruction(HloInstruction* instruction_to_merge);
   1439 
   1440   // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
   1441   void MergeFusionInstructionIntoMultiOutput(
   1442       HloInstruction* instruction_to_merge);
   1443 
   1444   // Delegates to HloFusionInstruction::FuseInstruction.
   1445   HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse);
   1446 
   1447   // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput.
   1448   HloInstruction* FuseInstructionIntoMultiOutput(
   1449       HloInstruction* instruction_to_fuse);
   1450 
    1451   // Delegates to HloFusionInstruction::fused_instructions_computation.
   1452   HloComputation* fused_instructions_computation() const;
   1453 
   1454   // Delegates to HloFusionInstruction::fused_expression_root.
   1455   HloInstruction* fused_expression_root() const;
   1456 
   1457   // Delegates to HloFusionInstruction::fused_instructions.
   1458   const tensorflow::gtl::iterator_range<UnwrappingIterator<
   1459       std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
   1460   fused_instructions() const;
   1461 
   1462   const tensorflow::gtl::iterator_range<
   1463       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
   1464   fused_instructions();
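  // Example (illustrative sketch): iterating over the instructions inside a
  // fusion. Assumes `fusion` is an HloInstruction* with opcode kFusion.
  //
  //   for (HloInstruction* fused : fusion->fused_instructions()) {
  //     VLOG(2) << "fused: " << fused->ToShortString();
  //   }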
   1465 
   1466   // Delegates to HloFusionInstruction::fused_instruction_count.
   1467   int64 fused_instruction_count() const;
   1468 
   1469   // Delegates to HloFusionInstruction::fused_parameter.
   1470   HloInstruction* fused_parameter(int64 parameter_number) const;
   1471 
   1472   // Delegates to HloFusionInstruction::fused_parameters.
   1473   const std::vector<HloInstruction*>& fused_parameters() const;
   1474 
   1475   // Returns true if this instruction is a fusion instruction that generates
   1476   // multiple outputs.
   1477   const bool IsMultiOutputFusion() const;
   1478 
   1479   // Delegates to HloFusionInstruction::fusion_kind.
   1480   FusionKind fusion_kind() const;
   1481 
   1482   // Delegates to HloFusionInstruction::set_fusion_kind.
   1483   void set_fusion_kind(FusionKind kind);
   1484 
   1485   // Delegates to HloRngInstruction::random_distribution.
   1486   RandomDistribution random_distribution() const;
   1487 
   1488   // Delegates to HloParameterInstruction::parameter_number.
   1489   int64 parameter_number() const;
   1490 
   1491   // Delegates to
   1492   // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers.
   1493   void set_parameter_replicated_at_leaf_buffers(
   1494       absl::Span<const bool> parameter_replicated_at_leaf_buffers);
   1495 
   1496   // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers.
   1497   const absl::optional<std::vector<bool>>&
   1498   parameter_replicated_at_leaf_buffers() const;
   1499 
   1500   // Delegates to HloGetTupleElementInstruction::tuple_index.
   1501   int64 tuple_index() const;
   1502 
   1503   // Delegates to HloReducePrecisionInstruction::exponent_bits.
   1504   int32 exponent_bits() const;
   1505 
   1506   // Delegates to HloReducePrecisionInstruction::mantissa_bits.
   1507   int32 mantissa_bits() const;
   1508 
   1509   // Delegates to HloInfeedInstruction::infeed_config.
   1510   string infeed_config() const;
   1511 
   1512   // Delegates to HloInfeedInstruction::set_infeed_config.
   1513   void set_infeed_config(const string& config);
   1514 
   1515   // Returns the config for the Outfeed instruction.
   1516   const string& outfeed_config() const;
   1517 
   1518   // Returns the shape for the Outfeed instruction.
   1519   const Shape& outfeed_shape() const;
   1520 
   1521   // Delegates to HloCollectiveInstruction::replica_groups.
   1522   const std::vector<ReplicaGroup>& replica_groups() const;
   1523 
   1524   // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
   1525   const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
   1526 
   1527   // Delegates to HloAllReduceInstruction::all_reduce_barrier.
   1528   string all_reduce_barrier() const;
   1529   void set_all_reduce_barrier(const string& barrier);
   1530 
   1531   // Delegates to HloAllReduceInstruction::all_reduce_id.
   1532   absl::optional<int64> all_reduce_id() const;
   1533   void set_all_reduce_id(const absl::optional<int64>& all_reduce_id);
   1534 
   1535   // Returns data on the window in a windowed operation such as
   1536   // convolution.
   1537   virtual const Window& window() const {
   1538     LOG(FATAL) << "Unimplemented method.";
   1539   }
   1540 
   1541   // Sets the window data in a windowed operation such as convolution.
   1542   virtual void set_window(const Window& window) {
   1543     LOG(FATAL) << "Unimplemented method.";
   1544   }
   1545 
   1546   // Returns data on the dimension numbers used for a convolution operation,
   1547   // which may be a kConvolution instruction or a kCustomCall that implements a
   1548   // convolution.
   1549   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const;
   1550 
   1551   // Sets the convolution dimension numbers on this instruction.  In general you
   1552   // shouldn't need to call this; instead, specify the convolution dimension
   1553   // numbers when you create the instruction.
   1554   void set_convolution_dimension_numbers(
   1555       const ConvolutionDimensionNumbers& dnums);
   1556 
   1557   // The number of feature groups. Must be a divisor of the input feature
   1558   // dimension and output feature dimension.
   1559   int64 feature_group_count() const;
   1560 
   1561   void set_feature_group_count(int64 feature_group_count);
   1562 
    1563   // The number of batch groups. Must be a divisor of the input batch dimension.
   1564   int64 batch_group_count() const;
   1565 
   1566   void set_batch_group_count(int64 batch_group_count);
   1567 
   1568   // Delegates to HloSelectAndScatterInstruction::select.
   1569   HloComputation* select() const;
   1570 
   1571   // Delegates to HloSelectAndScatterInstruction::scatter.
   1572   HloComputation* scatter() const;
   1573 
   1574   // Delegates to HloSelectAndScatterInstruction::set_select.
   1575   void set_select(HloComputation* computation);
   1576 
   1577   // Delegates to HloSelectAndScatterInstruction::set_scatter.
   1578   void set_scatter(HloComputation* computation);
   1579 
   1580   // Delegates to HloCustomCallInstruction::custom_call_target.
   1581   const string& custom_call_target() const;
   1582 
   1583   // Delegates to HloPadInstruction::padding_config.
   1584   const PaddingConfig& padding_config() const;
   1585 
   1586   // Delegates to HloDynamicSliceInstruction::slice_sizes.
   1587   int64 slice_sizes(int64 dimension) const;
   1588 
   1589   // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
   1590   const std::vector<int64>& dynamic_slice_sizes() const;
   1591 
   1592   // Delegates to HloGatherInstruction::gather_dimension_numbers.
   1593   const GatherDimensionNumbers& gather_dimension_numbers() const;
   1594   // Delegates to HloGatherInstruction::gather_slice_sizes.
   1595   absl::Span<const int64> gather_slice_sizes() const;
   1596 
   1597   // Delegates to HloScatterInstruction::scatter_dimension_numbers().
   1598   const ScatterDimensionNumbers& scatter_dimension_numbers() const;
   1599 
   1600   // Delegates to HloDotInstruction::dot_dimension_numbers().
   1601   const DotDimensionNumbers& dot_dimension_numbers() const;
   1602 
   1603   // Delegates to HloDomainInstruction::operand_side_metadata().
   1604   const DomainMetadata& operand_side_metadata() const;
   1605 
   1606   // Delegates to HloDomainInstruction::user_side_metadata().
   1607   const DomainMetadata& user_side_metadata() const;
   1608 
   1609   // Delegates to HloCompareInstruction::direction().
   1610   ComparisonDirection comparison_direction() const;
   1611 
   1612   // Delegates to HloTriangularSolveInstruction::triangular_solve_options().
   1613   const TriangularSolveOptions& triangular_solve_options() const;
   1614 
   1615   // Delegates to HloCholeskyInstruction::cholesky_options().
   1616   const CholeskyOptions& cholesky_options() const;
   1617 
   1618   // Old methods kept for smooth subclassing transition END.
   1619 
   1620  protected:
   1621   enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
   1622   // Helper class for computing OperandElementUse for kFusion.
   1623   class FusionReusesParamElements;
   1624 
    1625   // Internal constructor for a given opcode/shape; other fields must be filled
    1626   // in by factory methods.
   1627   HloInstruction(HloOpcode opcode, const Shape& shape);
   1628 
   1629   // Appends operand to the list of operands and adds this instruction as a user
   1630   // of the operand.
   1631   void AppendOperand(HloInstruction* operand);
   1632 
   1633   void RemoveOperandAt(int index) {
   1634     operands_.erase(operands_.begin() + index);
   1635   }
   1636 
   1637   // Removes a list of operands with the given indices in ascending order.
   1638   void RemoveOperandsAtAscendingIndices(
   1639       absl::Span<const int> ascending_indices);
   1640 
   1641   void AppendComputation(HloComputation* computation) {
   1642     called_computations_.push_back(computation);
   1643   }
   1644 
   1645   void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); }
   1646 
   1647   void set_called_computation(int index, HloComputation* computation) {
   1648     called_computations_[index] = computation;
   1649   }
   1650   // Indices of computations in called_computations_ for instructions which call
   1651   // multiple computations.
   1652   enum {
   1653     // kWhile computations.
   1654     kBodyComputationIndex = 0,
   1655     kConditionComputationIndex = 1,
   1656 
   1657     // kSelectAndScatter computations.
   1658     kSelectComputationIndex = 0,
   1659     kScatterComputationIndex = 1,
   1660 
   1661     // kConditional computations.
   1662     kTrueComputationIndex = 0,
   1663     kFalseComputationIndex = 1,
   1664   };
   1665 
   1666  private:
   1667   // Implementation for non-common logic of CloneWithNewOperands.
   1668   virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
   1669       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
   1670       HloCloneContext* context) const {
   1671     // TODO(b/80131774): This should be pure virtual.
   1672     LOG(FATAL) << "Unimplemented method.";
   1673   }
   1674 
   1675   // Implementation for non-common logic of ExtraAttributesToString.
   1676   virtual std::vector<string> ExtraAttributesToStringImpl(
   1677       const HloPrintOptions& options) const {
   1678     return {};
   1679   }
   1680 
    1681   // Implementation for IsElementwise if operand_idx is nullopt, and for
    1682   // IsElementwiseOnOperand otherwise.
   1683   //
   1684   // NOTE: For all instructions other than kFusion, being elementwise on one of
   1685   // the operands is equivalent to being elementwise on all the operands.
   1686   virtual bool IsElementwiseImpl(
   1687       const absl::optional<int64>& operand_idx) const;
   1688   // Prints an instruction to a string.
   1689   //
    1690   // The canonical string representation needs to name operands and instructions
    1691   // in a consistent way. This is implemented through the
   1692   // canonical_name_map.
   1693   string ToStringWithCanonicalNameMap(
   1694       const HloPrintOptions& options,
   1695       CanonicalNameMap* canonical_name_map) const;
   1696 
   1697   // Prints an operand to a string.
   1698   virtual string OperandsToStringWithCanonicalNameMap(
   1699       const HloPrintOptions& options,
   1700       CanonicalNameMap* canonical_name_map) const;
   1701 
   1702   // Allow HloInstruction to access the ToStringWithCanonicalNameMap() and
   1703   // OperandsToStringWithCanonicalNameMap() functions.
   1704   friend class HloComputation;
   1705 
   1706   // See comments on Identical().
   1707   virtual bool IdenticalSlowPath(
   1708       const HloInstruction& other,
   1709       const std::function<bool(const HloComputation*, const HloComputation*)>&
   1710           eq_computations) const;
   1711 
   1712   // Generates a hash value specific to a particular type of an instruction.
   1713   // This function typically considers the inner root instruction.
   1714   virtual uint64 InnerHash() const;
   1715 
   1716   // Creates an n-ary elementwise operation.
   1717   static std::unique_ptr<HloInstruction> CreateNary(
   1718       const Shape& shape, HloOpcode opcode,
   1719       absl::Span<HloInstruction* const> operands);
   1720 
   1721   // Adds a user for this instruction.
   1722   void AddUser(HloInstruction* user);
   1723 
   1724   // Removes a user for this instruction.
   1725   void RemoveUser(HloInstruction* user);
   1726 
   1727   // Returns how this instruction uses elements of its `i`th operand.
   1728   UseKind OperandElementUse(int64 i) const;
   1729 
   1730   // Helper for implementing backend_config().  Parses backend_config_ into the
   1731   // given proto.
   1732   Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const;
   1733 
   1734   int unique_id_;  // Unique to this HloInstruction within a HloModule
   1735 
   1736   // Opcode for this instruction.
   1737   HloOpcode opcode_;
   1738 
   1739   // Instruction operands.
   1740   InstructionVector operands_;
   1741 
   1742   // The set of control predecessors of this instruction.
   1743   // Note that the order of the instructions in the vector influences the order
   1744   // computed in HloComputation::ComputeInstructionPostOrder, which may
   1745   // influence the result of the compilation by changing the scheduling. We are
   1746   // not sure if it matters.
   1747   std::vector<HloInstruction*> control_predecessors_;
   1748 
   1749   // The users of this instruction. Users are HLOs where this instruction is an
   1750   // operand. The vector users_ and the set user_set_ contain identical
   1751   // members. The set enables fast membership testing and the vector enables
   1752   // fast, stable iteration.
   1753   std::vector<HloInstruction*> users_;
   1754   absl::flat_hash_set<const HloInstruction*> user_set_;
   1755 
   1756   // The set of control successors of this instruction.
   1757   std::vector<HloInstruction*> control_successors_;
   1758 
   1759   // The computation in which this instruction is contained.
   1760   HloComputation* parent_ = nullptr;
   1761 
   1762   // Result shape of this instruction.
   1763   Shape shape_;
   1764 
   1765   // The sharding, if one exists.
   1766   // Uses std::shared_ptr to allow reuse of the same sharding object between
    1767   // HloInstructions and other components, as HloSharding can be very large
    1768   // for tuples with many elements.
   1769   std::shared_ptr<const HloSharding> sharding_;
   1770 
   1771   // Computations called by this instruction.
   1772   std::vector<HloComputation*> called_computations_;
   1773 
   1774   // A trace instruction that consumes this instruction.
   1775   //
   1776   // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as
   1777   // an operand.
   1778   HloInstruction* trace_instruction_ = nullptr;
   1779 
   1780   // The backend-specific configuration for how a backend should compile this
   1781   // HLO. See the documentation on backend_config().
   1782   string backend_config_;
   1783 
   1784   // This field is assigned to true when backend_config_ is assigned to
   1785   // a default configuration.
   1786   bool is_default_config_ = false;
   1787 
   1788   // String identifier for instruction.
   1789   string name_;
   1790 
   1791   // Metadata for debugging.
   1792   OpMetadata metadata_;
   1793 
   1794   // The number of partitions per outer dimension (listed in order from
   1795   // outer-most dimension first).
   1796   std::vector<int64> outer_dimension_partitions_;
   1797 
   1798   TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
   1799 };
   1800 
   1801 // Explicit instantiations in hlo_instruction.cc.
   1802 extern template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
   1803 extern template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
   1804 
   1805 string ToString(HloInstruction::FusionKind kind);
   1806 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
   1807     const string& kind_name);
   1808 
   1809 // Custom (de)stringification functions for protos that live inside
   1810 // HloInstruction.
   1811 string PaddingConfigToString(const PaddingConfig& padding);
   1812 string OpMetadataToString(const OpMetadata& metadata);
   1813 string RandomDistributionToString(const RandomDistribution& distribution);
   1814 string PrecisionToString(const PrecisionConfig::Precision& precision);
   1815 string ConvolutionDimensionNumbersToString(
   1816     const ConvolutionDimensionNumbers& dnums);
   1817 
   1818 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
   1819 StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name);
   1820 
   1821 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
   1822 
   1823 // Map classes that guarantee a deterministic iteration order when the key is
   1824 // an HloInstruction* or a const HloInstruction*.
   1825 // To make the iteration order over the map deterministic, the comparator
    1826 // should not use the pointer values, but rather an intrinsic property of
    1827 // the HLO. Exception: null pointer values compare less than non-null.
   1828 struct HloPtrComparator {
   1829   bool operator()(const HloInstruction* const& lhs,
   1830                   const HloInstruction* const& rhs) const;
   1831 };
   1832 
   1833 template <typename ValueT>
   1834 using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;
   1835 
   1836 template <typename ValueT>
   1837 using ConstHloInstructionMap =
   1838     std::map<const HloInstruction*, ValueT, HloPtrComparator>;
   1839 
   1840 using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
   1841 using ConstHloInstructionSet =
   1842     std::set<const HloInstruction*, HloPtrComparator>;
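// Example (illustrative sketch): collecting per-instruction data with a
// deterministic iteration order, e.g. when emitting diagnostics. `instrs` and
// EstimateCost() are placeholders for caller-provided data and logic.
//
//   HloInstructionMap<int64> costs;
//   for (HloInstruction* instr : instrs) {
//     costs[instr] = EstimateCost(instr);
//   }
//   for (const auto& kv : costs) {
//     LOG(INFO) << kv.first->name() << ": " << kv.second;
//   }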
   1843 
   1844 }  // namespace xla
   1845 
   1846 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
   1847