      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     17 
     18 #include <algorithm>
     19 #include <deque>
     20 #include <ostream>
     21 #include <set>
     22 #include <unordered_set>
     23 #include <utility>
     24 
     25 #include "tensorflow/compiler/xla/layout_util.h"
     26 #include "tensorflow/compiler/xla/literal_util.h"
     27 #include "tensorflow/compiler/xla/protobuf_util.h"
     28 #include "tensorflow/compiler/xla/ptr_util.h"
     29 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
     30 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     31 #include "tensorflow/compiler/xla/service/hlo_module.h"
     32 #include "tensorflow/compiler/xla/service/name_uniquer.h"
     33 #include "tensorflow/compiler/xla/shape_util.h"
     34 #include "tensorflow/compiler/xla/status_macros.h"
     35 #include "tensorflow/compiler/xla/types.h"
     36 #include "tensorflow/compiler/xla/util.h"
     37 #include "tensorflow/compiler/xla/window_util.h"
     38 #include "tensorflow/core/lib/core/errors.h"
     39 #include "tensorflow/core/lib/gtl/flatmap.h"
     40 #include "tensorflow/core/lib/strings/str_util.h"
     41 #include "tensorflow/core/lib/strings/strcat.h"
     42 #include "tensorflow/core/platform/logging.h"
     43 
     44 namespace xla {
     45 
      46 using ::tensorflow::str_util::CEscape;
     47 using ::tensorflow::str_util::Join;
     48 using ::tensorflow::strings::StrAppend;
     49 using ::tensorflow::strings::StrCat;
     50 
     51 /* static */
     52 StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
     53     HloModule* module, const HloInstructionProto& proto,
     54     const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
     55     const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
     56     const std::function<void(std::unique_ptr<HloComputation>)>&
     57         add_fused_computation) {
     58   TF_RET_CHECK(!proto.opcode().empty());
     59   TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
     60   TF_RET_CHECK(proto.has_shape());
     61 
     62   auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
     63   for (const string& operand_name : proto.operand_names()) {
     64     TF_RET_CHECK(ContainsKey(instruction_map, operand_name))
     65         << "No instruction named " << operand_name;
     66     instruction->AppendOperand(instruction_map.at(operand_name));
     67   }
     68   for (const string& predecessor_name : proto.control_predecessor_names()) {
     69     TF_RET_CHECK(ContainsKey(instruction_map, predecessor_name))
     70         << "No instruction named " << predecessor_name;
     71     TF_RETURN_IF_ERROR(instruction_map.at(predecessor_name)
     72                            ->AddControlDependencyTo(instruction.get()));
     73   }
     74 
     75   // In the proto, fused computations are held exclusively within the
     76   // HloInstructionProto and do not appear as an HloComputationProto within the
     77   // HloModuleProto.
     78   if (instruction->opcode() == HloOpcode::kFusion) {
     79     TF_RET_CHECK(proto.has_fused_instructions_computation());
     80     TF_RET_CHECK(!proto.fusion_kind().empty());
     81     TF_ASSIGN_OR_RETURN(instruction->fusion_kind_,
     82                         StringToFusionKind(proto.fusion_kind()));
     83     TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> fused_computation,
     84                         HloComputation::CreateFromProto(
     85                             module, proto.fused_instructions_computation(),
     86                             computation_map, add_fused_computation,
     87                             /*fusion_instruction=*/instruction.get()));
     88     instruction->called_computations_.push_back(fused_computation.get());
     89     add_fused_computation(std::move(fused_computation));
     90   } else {
     91     for (const string& computation_name : proto.called_computation_names()) {
     92       TF_RET_CHECK(ContainsKey(computation_map, computation_name))
     93           << "No computation named " << computation_name;
     94       instruction->called_computations_.push_back(
     95           computation_map.at(computation_name));
     96     }
     97   }
     98 
     99   TF_RET_CHECK(!proto.name().empty());
    100   instruction->name_ = proto.name();
    101 
    102   instruction->metadata_ = proto.metadata();
    103   if (proto.has_literal()) {
    104     TF_ASSIGN_OR_RETURN(instruction->literal_,
    105                         Literal::CreateFromProto(proto.literal()));
    106   }
    107   instruction->parameter_number_ = proto.parameter_number();
    108 
    109   instruction->tuple_index_ = proto.tuple_index();
    110   for (int64 dimension : proto.dimensions()) {
    111     instruction->dimensions_.push_back(dimension);
    112   }
    113   if (proto.has_window()) {
    114     instruction->window_ = MakeUnique<Window>(proto.window());
    115   }
    116   if (proto.has_convolution_dimension_numbers()) {
    117     instruction->convolution_dimension_numbers_ =
    118         MakeUnique<ConvolutionDimensionNumbers>(
    119             proto.convolution_dimension_numbers());
    120   }
    121   if (proto.has_dot_dimension_numbers()) {
    122     instruction->dot_dimension_numbers_ =
    123         MakeUnique<DotDimensionNumbers>(proto.dot_dimension_numbers());
    124   }
    125   for (const HloInstructionProto::SliceDimensions& slice_dimensions :
    126        proto.slice_dimensions()) {
    127     instruction->slice_starts_.push_back(slice_dimensions.start());
    128     instruction->slice_limits_.push_back(slice_dimensions.limit());
    129     instruction->slice_strides_.push_back(slice_dimensions.stride());
    130   }
    131   instruction->exponent_bits_ = proto.exponent_bits();
    132   instruction->mantissa_bits_ = proto.mantissa_bits();
    133   for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) {
    134     instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size);
    135   }
    136   if (proto.has_padding_config()) {
    137     instruction->padding_config_ =
    138         MakeUnique<PaddingConfig>(proto.padding_config());
    139   }
    140   instruction->outfeed_config_ = proto.outfeed_config();
    141   instruction->distribution_ = proto.distribution();
    142   instruction->epsilon_ = proto.epsilon();
    143   instruction->feature_index_ = proto.feature_index();
    144   instruction->channel_id_ = proto.channel_id();
    145   instruction->infeed_config_ = proto.infeed_config();
    146   instruction->custom_call_target_ = proto.custom_call_target();
    147   instruction->outfeed_shape_ = proto.outfeed_shape();
    148   instruction->fft_type_ = proto.fft_type();
    149   for (int64 fft_len : proto.fft_length()) {
    150     instruction->fft_length_.push_back(fft_len);
    151   }
    152 
    153   return std::move(instruction);
    154 }
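         // Illustrative sketch (not part of the original file): one plausible way a
         // deserializer might call CreateFromProto. The module, proto, maps, and the
         // add_fused lambda below are hypothetical stand-ins for the caller's state.
         //
         //   tensorflow::gtl::FlatMap<string, HloInstruction*> instruction_map;
         //   tensorflow::gtl::FlatMap<string, HloComputation*> computation_map;
         //   auto add_fused = [&](std::unique_ptr<HloComputation> fused) {
         //     module->AddEmbeddedComputation(std::move(fused));
         //   };
         //   TF_ASSIGN_OR_RETURN(
         //       std::unique_ptr<HloInstruction> instruction,
         //       HloInstruction::CreateFromProto(module, instruction_proto,
         //                                       instruction_map, computation_map,
         //                                       add_fused));
         //   instruction_map[instruction->name()] = instruction.get();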
    155 
    156 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
    157     int64 parameter_number, const Shape& shape, const string& name) {
    158   auto instruction =
    159       WrapUnique(new HloInstruction(HloOpcode::kParameter, shape));
    160   instruction->parameter_number_ = parameter_number;
    161   instruction->name_ = name;
    162   return instruction;
    163 }
    164 
    165 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
    166     const string& tag, HloInstruction* operand) {
    167   auto instruction =
    168       WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
    169   instruction->operands_.push_back(operand);
    170   instruction->literal_ = Literal::CreateR1U8(tag);
    171   return instruction;
    172 }
    173 
    174 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
    175     std::unique_ptr<Literal> literal) {
    176   auto instruction =
    177       WrapUnique(new HloInstruction(HloOpcode::kConstant, literal->shape()));
    178   instruction->literal_ = std::move(literal);
    179   return instruction;
    180 }
    181 
    182 /* static */ std::unique_ptr<HloInstruction>
    183 HloInstruction::CreateGetTupleElement(const Shape& shape,
    184                                       HloInstruction* operand, int64 index) {
    185   auto instruction =
    186       WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape));
    187   instruction->tuple_index_ = index;
    188   instruction->AppendOperand(operand);
    189   return instruction;
    190 }
    191 
    192 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
    193     const Shape& shape, RandomDistribution distribution,
    194     tensorflow::gtl::ArraySlice<HloInstruction*> parameters) {
    195   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRng, shape));
    196   instruction->distribution_ = distribution;
    197   instruction->shape_ = shape;
    198   for (HloInstruction* param : parameters) {
    199     instruction->AppendOperand(param);
    200   }
    201   return instruction;
    202 }
    203 
    204 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
    205     const Shape& shape, HloOpcode opcode,
    206     tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
    207   if (opcode == HloOpcode::kCopy) {
     208     // It is impossible to copy an opaque shape; we don't know how big it is.
    209     CHECK(!ShapeUtil::IsOpaque(shape));
    210   }
    211   auto instruction = WrapUnique(new HloInstruction(opcode, shape));
    212   for (auto operand : operands) {
    213     instruction->AppendOperand(operand);
    214   }
    215   return instruction;
    216 }
    217 
    218 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
    219     const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
    220   // Only certain opcodes are supported with CreateUnary: opcodes of unary
    221   // instructions with no auxiliary fields.
    222   switch (opcode) {
    223     case HloOpcode::kAbs:
    224     case HloOpcode::kRoundNearestAfz:
    225     case HloOpcode::kBitcast:
    226     case HloOpcode::kCeil:
    227     case HloOpcode::kCopy:
    228     case HloOpcode::kCos:
    229     case HloOpcode::kExp:
    230     case HloOpcode::kFloor:
    231     case HloOpcode::kImag:
    232     case HloOpcode::kIsFinite:
    233     case HloOpcode::kLog:
    234     case HloOpcode::kNot:
    235     case HloOpcode::kNegate:
    236     case HloOpcode::kReal:
    237     case HloOpcode::kSign:
    238     case HloOpcode::kSin:
    239     case HloOpcode::kSort:
    240     case HloOpcode::kTanh:
    241       break;
    242     default:
    243       LOG(FATAL) << "Invalid unary instruction opcode "
    244                  << HloOpcodeString(opcode);
    245   }
    246   return CreateNary(shape, opcode, {operand});
    247 }
    248 
    249 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
    250     const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
    251     HloInstruction* rhs) {
    252   // Only certain opcodes are supported with CreateBinary: opcodes of binary
    253   // instructions with no auxiliary fields.
    254   switch (opcode) {
    255     case HloOpcode::kAdd:
    256     case HloOpcode::kAtan2:
    257     case HloOpcode::kDivide:
    258     case HloOpcode::kComplex:
    259     case HloOpcode::kDot:
    260     case HloOpcode::kEq:
    261     case HloOpcode::kGe:
    262     case HloOpcode::kGt:
    263     case HloOpcode::kLe:
    264     case HloOpcode::kLt:
    265     case HloOpcode::kMaximum:
    266     case HloOpcode::kMinimum:
    267     case HloOpcode::kMultiply:
    268     case HloOpcode::kNe:
    269     case HloOpcode::kPower:
    270     case HloOpcode::kRemainder:
    271     case HloOpcode::kSubtract:
    272     case HloOpcode::kAnd:
    273     case HloOpcode::kOr:
    274     case HloOpcode::kShiftLeft:
    275     case HloOpcode::kShiftRightArithmetic:
    276     case HloOpcode::kShiftRightLogical:
    277       break;
    278     default:
    279       LOG(FATAL) << "Invalid binary instruction opcode "
    280                  << HloOpcodeString(opcode);
    281   }
    282   return CreateNary(shape, opcode, {lhs, rhs});
    283 }
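         // Illustrative usage sketch (not part of the original file), assuming the
         // caller owns an HloComputation::Builder named `builder`:
         //
         //   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
         //   HloInstruction* x = builder.AddInstruction(
         //       HloInstruction::CreateParameter(0, r0f32, "x"));
         //   HloInstruction* y = builder.AddInstruction(
         //       HloInstruction::CreateParameter(1, r0f32, "y"));
         //   HloInstruction* sum = builder.AddInstruction(
         //       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, x, y));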
    284 
    285 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
    286     const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
    287     HloInstruction* rhs, HloInstruction* ehs) {
    288   // Only certain opcodes are supported with CreateTernary: opcodes of ternary
    289   // instructions with no auxiliary fields.
    290   switch (opcode) {
    291     case (HloOpcode::kClamp):
    292     case (HloOpcode::kSelect):
    293       break;
    294     default:
    295       LOG(FATAL) << "Invalid ternary instruction opcode "
    296                  << HloOpcodeString(opcode);
    297   }
    298   return CreateNary(shape, opcode, {lhs, rhs, ehs});
    299 }
    300 
    301 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
    302     const Shape& shape, HloOpcode opcode,
    303     tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
    304   CHECK_EQ(HloOpcode::kTuple, opcode);
    305   return CreateNary(shape, opcode, operands);
    306 }
    307 
    308 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
    309     const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    310     HloComputation* map_computation,
    311     tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) {
    312   CHECK(static_operands.empty()) << "static_operands not yet supported";
    313   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kMap, shape));
    314   for (auto operand : operands) {
    315     instruction->AppendOperand(operand);
    316   }
    317   instruction->called_computations_.push_back(map_computation);
    318   return instruction;
    319 }
    320 
    321 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
    322     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    323     const Window& window,
    324     const ConvolutionDimensionNumbers& dimension_numbers) {
    325   auto instruction =
    326       WrapUnique(new HloInstruction(HloOpcode::kConvolution, shape));
    327   if (window_util::HasBaseDilation(window)) {
    328     instruction->name_ = instruction->name() + "-base-dilated";
    329   }
    330   if (window_util::HasWindowDilation(window)) {
    331     instruction->name_ = instruction->name() + "-window-dilated";
    332   }
    333   instruction->AppendOperand(lhs);
    334   instruction->AppendOperand(rhs);
    335   instruction->window_ = MakeUnique<Window>(window);
    336   instruction->convolution_dimension_numbers_ =
    337       MakeUnique<ConvolutionDimensionNumbers>(dimension_numbers);
    338   return instruction;
    339 }
    340 
    341 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
    342     const Shape& shape, HloInstruction* operand, FftType fft_type,
    343     tensorflow::gtl::ArraySlice<int64> fft_length) {
    344   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape));
    345   instruction->AppendOperand(operand);
    346   instruction->fft_type_ = fft_type;
    347   instruction->fft_length_.assign(fft_length.begin(), fft_length.end());
    348   return instruction;
    349 }
    350 
    351 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
    352     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    353     const DotDimensionNumbers& dimension_numbers) {
    354   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
    355   instruction->AppendOperand(lhs);
    356   instruction->AppendOperand(rhs);
    357   instruction->dot_dimension_numbers_ =
    358       MakeUnique<DotDimensionNumbers>(dimension_numbers);
    359   return instruction;
    360 }
    361 
    362 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCanonicalDot(
    363     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) {
    364   CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
    365   CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
    366 
    367   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
    368   instruction->AppendOperand(lhs);
    369   instruction->AppendOperand(rhs);
    370   instruction->dot_dimension_numbers_ = MakeUnique<DotDimensionNumbers>();
    371   instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1);
    372   instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0);
    373   return instruction;
    374 }
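         // For example (illustrative, not part of the original file): with lhs of
         // shape f32[m, k] and rhs of shape f32[k, n], the canonical dot contracts
         // lhs dimension 1 against rhs dimension 0 and yields f32[m, n]:
         //
         //   auto dot = HloInstruction::CreateCanonicalDot(
         //       ShapeUtil::MakeShape(F32, {m, n}), lhs, rhs);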
    375 
    376 /* static */ std::unique_ptr<HloInstruction>
    377 HloInstruction::CreateReducePrecision(const Shape& shape,
    378                                       HloInstruction* operand,
    379                                       const int exponent_bits,
    380                                       const int mantissa_bits) {
    381   auto instruction =
    382       WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape));
    383   instruction->AppendOperand(operand);
    384   instruction->exponent_bits_ = exponent_bits;
    385   instruction->mantissa_bits_ = mantissa_bits;
    386   return instruction;
    387 }
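         // Illustrative example (a typical use, assumed rather than taken from this
         // file): emulating bfloat16 rounding of an f32 operand by keeping 8
         // exponent bits and 7 mantissa bits:
         //
         //   auto reduced = HloInstruction::CreateReducePrecision(
         //       operand->shape(), operand, /*exponent_bits=*/8, /*mantissa_bits=*/7);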
    388 
    389 /* static */ std::unique_ptr<HloInstruction>
    390 HloInstruction::CreateCrossReplicaSum(
    391     const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
    392   return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands);
    393 }
    394 
    395 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
    396     const Shape& shape, const string& config) {
    397   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape));
    398   instruction->set_infeed_config(config);
    399   return instruction;
    400 }
    401 
    402 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
    403     const Shape& shape, HloInstruction* operand,
    404     tensorflow::StringPiece outfeed_config) {
    405   std::unique_ptr<HloInstruction> instruction =
    406       WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil()));
    407   CHECK(ShapeUtil::Compatible(operand->shape(), shape))
    408       << "Outfeed shape " << shape << " must be compatible with operand shape "
    409       << operand->shape();
    410   instruction->AppendOperand(operand);
    411   instruction->outfeed_config_ = outfeed_config.ToString();
    412   instruction->outfeed_shape_ = shape;
    413   return instruction;
    414 }
    415 
    416 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
    417     HloInstruction* operand, int64 channel_id) {
    418   // Send instruction produces a tuple of {aliased operand, U32 context}.
    419   Shape output_shape = ShapeUtil::MakeTupleShape(
    420       {operand->shape(), ShapeUtil::MakeShape(U32, {})});
    421   auto instruction =
    422       WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape));
    423   instruction->AppendOperand(operand);
    424   instruction->channel_id_ = channel_id;
    425   return instruction;
    426 }
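         // For example (illustrative, not part of the original file): sending an
         // f32[2] operand on channel 1 produces a kSend whose shape is the tuple
         // (f32[2], u32[]), which a later kSendDone consumes:
         //
         //   auto send = HloInstruction::CreateSend(operand, /*channel_id=*/1);
         //   auto send_done = HloInstruction::CreateSendDone(send.get());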
    427 
    428 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
    429     HloInstruction* operand) {
    430   CHECK(operand->opcode() == HloOpcode::kSend)
    431       << "SendDone must take the context operand from Send";
    432   auto instruction = WrapUnique(
    433       new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil()));
    434   instruction->AppendOperand(operand);
    435   instruction->channel_id_ = operand->channel_id();
    436   return instruction;
    437 }
    438 
    439 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
    440     const Shape& shape, int64 channel_id) {
    441   // Recv instruction produces a tuple of {receive buffer, U32 context}.
    442   Shape output_shape =
    443       ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
    444   auto instruction =
    445       WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape));
    446   instruction->channel_id_ = channel_id;
    447   return instruction;
    448 }
    449 
    450 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
    451     HloInstruction* operand) {
    452   CHECK(operand->opcode() == HloOpcode::kRecv)
    453       << "RecvDone must take the context operand from Recv";
    454   Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0);
    455   auto instruction =
    456       WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape));
    457   instruction->AppendOperand(operand);
    458   instruction->channel_id_ = operand->channel_id();
    459   return instruction;
    460 }
    461 
    462 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
    463     const Shape& shape, HloInstruction* operand,
    464     tensorflow::gtl::ArraySlice<int64> dimensions) {
    465   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReverse, shape));
    466   instruction->AppendOperand(operand);
    467   instruction->dimensions_.assign(dimensions.begin(), dimensions.end());
    468   return instruction;
    469 }
    470 
    471 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
    472     const Shape& shape, HloComputation* condition, HloComputation* body,
    473     HloInstruction* init) {
    474   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
    475   instruction->AppendOperand(init);
     476   // The body computation comes before the condition in the vector.
    477   instruction->called_computations_.push_back(body);
    478   instruction->called_computations_.push_back(condition);
    479   return instruction;
    480 }
    481 
    482 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
    483     const Shape& shape, HloInstruction* pred,
    484     HloInstruction* true_computation_arg, HloComputation* true_computation,
    485     HloInstruction* false_computation_arg, HloComputation* false_computation) {
    486   auto instruction =
    487       WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
    488   instruction->AppendOperand(pred);
    489   instruction->AppendOperand(true_computation_arg);
    490   instruction->AppendOperand(false_computation_arg);
    491   // In called_computations_, the index of true_computation must be 0 and that
    492   // of false computation must be 1, as defined by kTrueComputationIndex and
    493   // kFalseComputationIndex.
    494   instruction->called_computations_.push_back(true_computation);
    495   instruction->called_computations_.push_back(false_computation);
    496   return instruction;
    497 }
    498 
    499 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
    500     const Shape& shape, HloInstruction* operand,
    501     tensorflow::gtl::ArraySlice<int64> start_indices,
    502     tensorflow::gtl::ArraySlice<int64> limit_indices,
    503     tensorflow::gtl::ArraySlice<int64> strides) {
    504   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape));
    505   instruction->AppendOperand(operand);
    506   instruction->slice_starts_.assign(start_indices.begin(), start_indices.end());
    507   instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end());
    508   instruction->slice_strides_.assign(strides.begin(), strides.end());
    509   // For backward compatibility with old serialized computations: if there are
    510   // no strides, assume all strides are 1.
    511   // TODO(b/63317920): remove this code.
    512   if (instruction->slice_strides_.empty()) {
    513     instruction->slice_strides_ = std::vector<int64>(start_indices.size(), 1LL);
    514   }
    515   return instruction;
    516 }
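         // Worked example (illustrative, not part of the original file): slicing an
         // f32[10] operand with start {2}, limit {8}, and stride {2} keeps elements
         // 2, 4, and 6, so the result shape is f32[3]:
         //
         //   auto slice = HloInstruction::CreateSlice(
         //       ShapeUtil::MakeShape(F32, {3}), operand,
         //       /*start_indices=*/{2}, /*limit_indices=*/{8}, /*strides=*/{2});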
    517 
    518 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
    519     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
    520     tensorflow::gtl::ArraySlice<int64> slice_sizes) {
    521   auto instruction =
    522       WrapUnique(new HloInstruction(HloOpcode::kDynamicSlice, shape));
    523   instruction->AppendOperand(operand);
    524   instruction->AppendOperand(start_indices);
    525   instruction->dynamic_slice_sizes_.assign(slice_sizes.begin(),
    526                                            slice_sizes.end());
    527   return instruction;
    528 }
    529 
    530 /* static */ std::unique_ptr<HloInstruction>
    531 HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
    532                                          HloInstruction* operand,
    533                                          HloInstruction* update,
    534                                          HloInstruction* start_indices) {
    535   auto instruction =
    536       WrapUnique(new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape));
    537   instruction->AppendOperand(operand);
    538   instruction->AppendOperand(update);
    539   instruction->AppendOperand(start_indices);
    540   return instruction;
    541 }
    542 
    543 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
    544     const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    545     int64 dimension) {
    546   auto instruction =
    547       WrapUnique(new HloInstruction(HloOpcode::kConcatenate, shape));
    548   for (auto operand : operands) {
    549     instruction->AppendOperand(operand);
    550   }
    551   instruction->dimensions_.push_back(dimension);
    552   return instruction;
    553 }
    554 
    555 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
    556     const Shape& shape, HloInstruction* operand) {
    557   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
    558   instruction->AppendOperand(operand);
    559   return instruction;
    560 }
    561 
    562 /* static */ std::unique_ptr<HloInstruction>
    563 HloInstruction::CreateBitcastConvert(const Shape& shape,
    564                                      HloInstruction* operand) {
    565   auto instruction =
    566       WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
    567   instruction->AppendOperand(operand);
    568   return instruction;
    569 }
    570 
    571 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
    572     const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
    573     tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
    574     HloComputation* reduce_computation) {
    575   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReduce, shape));
    576   instruction->AppendOperand(arg);
    577   instruction->AppendOperand(init_value);
    578   instruction->dimensions_.assign(dimensions_to_reduce.begin(),
    579                                   dimensions_to_reduce.end());
    580   instruction->called_computations_.push_back(reduce_computation);
    581   return instruction;
    582 }
    583 
    584 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
    585     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    586     const Window& window, HloComputation* reduce_computation) {
    587   auto instruction =
    588       WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape));
    589   instruction->AppendOperand(operand);
    590   instruction->AppendOperand(init_value);
    591   instruction->called_computations_.push_back(reduce_computation);
    592   instruction->window_ = MakeUnique<Window>(window);
    593   return instruction;
    594 }
    595 
    596 /* static */ std::unique_ptr<HloInstruction>
    597 HloInstruction::CreateBatchNormTraining(const Shape& shape,
    598                                         HloInstruction* operand,
    599                                         HloInstruction* scale,
    600                                         HloInstruction* offset, float epsilon,
    601                                         int64 feature_index) {
    602   auto instruction =
    603       WrapUnique(new HloInstruction(HloOpcode::kBatchNormTraining, shape));
    604   instruction->AppendOperand(operand);
    605   instruction->AppendOperand(scale);
    606   instruction->AppendOperand(offset);
    607   instruction->epsilon_ = epsilon;
    608   instruction->feature_index_ = feature_index;
    609   return instruction;
    610 }
    611 
    612 /* static */ std::unique_ptr<HloInstruction>
    613 HloInstruction::CreateBatchNormInference(
    614     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    615     HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
    616     float epsilon, int64 feature_index) {
    617   auto instruction =
    618       WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape));
    619   instruction->AppendOperand(operand);
    620   instruction->AppendOperand(scale);
    621   instruction->AppendOperand(offset);
    622   instruction->AppendOperand(mean);
    623   instruction->AppendOperand(variance);
    624   instruction->epsilon_ = epsilon;
    625   instruction->feature_index_ = feature_index;
    626   return instruction;
    627 }
    628 
    629 /* static */ std::unique_ptr<HloInstruction>
    630 HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
    631                                     HloInstruction* scale, HloInstruction* mean,
    632                                     HloInstruction* variance,
    633                                     HloInstruction* grad_output, float epsilon,
    634                                     int64 feature_index) {
    635   auto instruction =
    636       WrapUnique(new HloInstruction(HloOpcode::kBatchNormGrad, shape));
    637   instruction->AppendOperand(operand);
    638   instruction->AppendOperand(scale);
    639   instruction->AppendOperand(mean);
    640   instruction->AppendOperand(variance);
    641   instruction->AppendOperand(grad_output);
    642   instruction->epsilon_ = epsilon;
    643   instruction->feature_index_ = feature_index;
    644   return instruction;
    645 }
    646 
    647 /* static */ std::unique_ptr<HloInstruction>
    648 HloInstruction::CreateSelectAndScatter(
    649     const Shape& shape, HloInstruction* operand, HloComputation* select,
    650     const Window& window, HloInstruction* source, HloInstruction* init_value,
    651     HloComputation* scatter) {
    652   auto instruction =
    653       WrapUnique(new HloInstruction(HloOpcode::kSelectAndScatter, shape));
    654   instruction->AppendOperand(operand);
    655   instruction->AppendOperand(source);
    656   instruction->AppendOperand(init_value);
    657   // Select comes before scatter in the vector.
    658   instruction->called_computations_.push_back(select);
    659   instruction->called_computations_.push_back(scatter);
    660   instruction->window_ = MakeUnique<Window>(window);
    661   return instruction;
    662 }
    663 
    664 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
    665     const Shape& shape, HloInstruction* operand,
    666     tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
    667   auto instruction =
    668       WrapUnique(new HloInstruction(HloOpcode::kBroadcast, shape));
    669   instruction->AppendOperand(operand);
    670   instruction->dimensions_.assign(broadcast_dimensions.begin(),
    671                                   broadcast_dimensions.end());
    672   return instruction;
    673 }
    674 
    675 /* static */ std::unique_ptr<HloInstruction>
    676 HloInstruction::CreateBroadcastSequence(
    677     const Shape& output_shape, HloInstruction* operand,
    678     const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
    679         adder) {
    680   CHECK(ShapeUtil::IsScalar(operand->shape()) ||
    681         ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape));
    682   Shape broadcast_shape = ShapeUtil::ChangeElementType(
    683       output_shape, operand->shape().element_type());
    684   // Do explicit broadcast for scalar.
    685   if (ShapeUtil::IsScalar(operand->shape())) {
    686     auto broadcast =
    687         HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
    688     broadcast->set_metadata(operand->metadata());
    689     if (operand->has_sharding()) {
    690       broadcast->set_sharding(operand->sharding());
    691     }
    692     return broadcast;
    693   }
    694   // Do explicit broadcast for degenerate broadcast.
    695   std::vector<int64> broadcast_dimensions;
    696   std::vector<int64> reshaped_dimensions;
    697   for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) {
    698     if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
    699       broadcast_dimensions.push_back(i);
    700       reshaped_dimensions.push_back(operand->shape().dimensions(i));
    701     } else {
    702       CHECK_EQ(operand->shape().dimensions(i), 1)
    703           << "An explicit broadcast sequence requires the broadcasted "
    704              "dimensions to be trivial; operand: "
    705           << operand->ToString() << "; output_shape: " << output_shape;
    706     }
    707   }
    708   // Eliminate the size one dimensions.
    709   HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
    710       ShapeUtil::MakeShape(operand->shape().element_type(),
    711                            reshaped_dimensions),
    712       operand));
    713   reshaped_operand->set_metadata(operand->metadata());
    714   if (operand->has_sharding()) {
    715     reshaped_operand->set_sharding(operand->sharding());
    716   }
    717   // Broadcast 'reshape' up to the larger size.
    718   auto broadcast = HloInstruction::CreateBroadcast(
    719       broadcast_shape, reshaped_operand, broadcast_dimensions);
    720   broadcast->set_metadata(operand->metadata());
    721   if (operand->has_sharding()) {
    722     broadcast->set_sharding(operand->sharding());
    723   }
    724   return broadcast;
    725 }
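         // Worked example (illustrative): broadcasting an f32[3, 1, 5] operand to an
         // f32[3, 4, 5] output. Dimension 1 is the degenerate size-one dimension, so
         // the sequence above first reshapes the operand to f32[3, 5] and then
         // broadcasts that reshape with broadcast_dimensions {0, 2} to reach the
         // full f32[3, 4, 5] shape.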
    726 
    727 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
    728     const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
    729     const PaddingConfig& padding_config) {
    730   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kPad, shape));
    731   instruction->AppendOperand(operand);
    732   instruction->AppendOperand(padding_value);
    733   instruction->padding_config_ = MakeUnique<PaddingConfig>(padding_config);
    734   return instruction;
    735 }
    736 
    737 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
    738     const Shape& shape, HloInstruction* operand) {
    739   CHECK_EQ(ShapeUtil::ElementsIn(shape),
    740            ShapeUtil::ElementsIn(operand->shape()))
    741       << "shape: " << ShapeUtil::HumanString(shape)
    742       << " operand: " << ShapeUtil::HumanString(operand->shape());
    743   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
    744   instruction->AppendOperand(operand);
    745   return instruction;
    746 }
    747 
    748 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
    749     const Shape& shape, HloInstruction* operand,
    750     tensorflow::gtl::ArraySlice<int64> dimensions) {
    751   CHECK_EQ(shape.dimensions().size(), dimensions.size());
    752   CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size());
    753   CHECK(std::equal(operand->shape().dimensions().begin(),
    754                    operand->shape().dimensions().end(),
    755                    Permute(dimensions, shape.dimensions()).begin()))
    756       << "shape: " << ShapeUtil::HumanString(shape)
     757       << ", operand->shape(): " << ShapeUtil::HumanString(operand->shape())
    758       << ", dimensions: {" << Join(dimensions, ", ") << "}";
    759   auto instruction =
    760       WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape));
    761   instruction->AppendOperand(operand);
    762   instruction->dimensions_.assign(dimensions.begin(), dimensions.end());
    763   return instruction;
    764 }
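         // For example (illustrative): transposing an f32[2, 3] operand with
         // dimensions {1, 0} swaps the two dimensions and produces an f32[3, 2]
         // result; the CHECKs above verify that `dimensions` is a permutation
         // consistent with the operand and output shapes.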
    765 
    766 // We put the fusion kind into the instruction's name for transpose-dot fusions,
    767 // since those fusions are really just describing a type of dot rather than
    768 // generating a novel computation.
    769 static string FusionNodeName(HloInstruction::FusionKind fusion_kind) {
    770   switch (fusion_kind) {
    771     case HloInstruction::FusionKind::kTransposeDot:
    772       return "dot_fusion";
    773     default:
    774       return "fusion";
    775   }
    776 }
    777 
    778 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
    779     const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
    780   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
    781   instruction->fusion_kind_ = fusion_kind;
    782   instruction->name_ = FusionNodeName(fusion_kind);
    783   instruction->set_parent(fused_root->parent());
    784   instruction->set_metadata(fused_root->metadata());
    785   instruction->CloneAndFuseInternal(fused_root);
    786   return instruction;
    787 }
    788 
    789 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
    790     const Shape& shape, FusionKind fusion_kind,
    791     tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    792     HloComputation* fusion_computation) {
    793   auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
    794   for (auto operand : operands) {
    795     instruction->AppendOperand(operand);
    796   }
    797   instruction->fusion_kind_ = fusion_kind;
    798   instruction->name_ = FusionNodeName(fusion_kind);
    799   instruction->called_computations_.push_back(fusion_computation);
    800   fusion_computation->SetFusionInstruction(instruction.get());
    801   return instruction;
    802 }
    803 
    804 HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
    805   CHECK_EQ(opcode(), HloOpcode::kFusion);
    806   CHECK_EQ(operand_count(),
    807            fused_instructions_computation()->parameter_instructions().size());
    808   const int64 param_no = operand_count();
    809   // Name the parameter after the instruction it represents in the outer
    810   // (non-fusion) computation.
    811   string param_name = StrCat(new_operand->name(), ".param_", param_no);
    812   HloInstruction* fused_parameter =
    813       fused_instructions_computation()->AddParameter(
    814           HloInstruction::CreateParameter(param_no, new_operand->shape(),
    815                                           param_name));
    816   AppendOperand(new_operand);
    817   return fused_parameter;
    818 }
    819 
    820 void HloInstruction::MergeFusionInstruction(
    821     HloInstruction* instruction_to_merge) {
    822   CHECK_EQ(opcode_, HloOpcode::kFusion);
    823   CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion);
    824   CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) !=
    825         operands().end());
    826   // Clone the instruction from which to merge fused instructions.
    827   std::unique_ptr<HloInstruction> clone = instruction_to_merge->Clone();
    828   // Replace uses of fused parameters with the corresponding operand of the
    829   // fusion.  Add all non-parameter fused instructions to 'unfused_instructions'
    830   // to be merged into 'this'.  This is done in reverse post order.
    831   std::vector<HloInstruction*> unfused_instructions;
    832   auto fused_instructions =
    833       clone->fused_instructions_computation()->MakeInstructionPostOrder();
    834   for (auto fused_it = fused_instructions.rbegin();
    835        fused_it != fused_instructions.rend(); ++fused_it) {
    836     auto fused_instruction = *fused_it;
    837     if (fused_instruction->opcode() == HloOpcode::kParameter) {
    838       TF_CHECK_OK(fused_instruction->ReplaceAllUsesWith(
    839           clone->mutable_operand(fused_instruction->parameter_number())));
    840     } else {
    841       unfused_instructions.push_back(fused_instruction);
    842     }
    843   }
    844   CHECK(unfused_instructions.front() == clone->fused_expression_root());
    845   // Replace instruction_to_merge use of 'this' with unfused_root.
    846   TF_CHECK_OK(
    847       instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front()));
    848   // Fuse 'unfused_instructions' into 'this'.
    849   for (auto& instruction : unfused_instructions) {
    850     FuseInstruction(instruction);
    851     instruction->DetachFromOperands();
    852   }
    853   CHECK_EQ(0, clone->user_count());
    854   clone->DetachFromOperands();
    855   TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
    856       clone->fused_instructions_computation()));
    857 }
    858 
    859 void HloInstruction::MergeFusionInstructionIntoMultiOutput(
    860     HloInstruction* instruction_to_merge) {
    861   CHECK_EQ(opcode_, HloOpcode::kFusion);
    862   CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion);
    863   // Add all non-parameter fused instructions to 'unfused_instructions' to be
    864   // merged into 'this'. `old_to_new' maps the instructions in the fused node
     865   // to the disassembled fusion instructions.
    866   // Note that we add the unfused instructions to this->parent_ computation.
     867   // This is necessary because each instruction needs a unique_id, which is
     868   // only assigned when the instruction is added to a computation.
    869   tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> old_to_new;
    870   std::vector<HloInstruction*> unfused_instructions;
    871   auto computation_to_merge =
    872       instruction_to_merge->fused_instructions_computation();
    873   auto post_order = computation_to_merge->MakeInstructionPostOrder();
    874   for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
    875     auto fused_instruction = *rit;
    876     if (fused_instruction->opcode() == HloOpcode::kParameter) {
    877       InsertOrDie(&old_to_new, fused_instruction,
    878                   instruction_to_merge->mutable_operand(
    879                       fused_instruction->parameter_number()));
    880       continue;
    881     }
    882 
     883     // Here we clone the instruction and call FuseInstructionIntoMultiOutput()
    884     // which clones again. This can be improved.
    885     auto cloned_instruction =
    886         parent_->AddInstruction(fused_instruction->Clone());
    887     unfused_instructions.push_back(cloned_instruction);
    888     InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
    889   }
    890   for (auto unfused_instruction : unfused_instructions) {
    891     for (int64 index = 0; index < unfused_instruction->operand_count();
    892          index++) {
    893       auto new_operand =
    894           FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
    895       TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
    896     }
    897   }
    898 
    899   HloInstruction* unfused_root = unfused_instructions.front();
    900   TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
    901 
    902   TF_CHECK_OK(
    903       instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
    904   if (GetModule()) {
    905     TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
    906   }
    907 
    908   // Fuse the root instruction and generate multiple outputs.
    909   FuseInstructionIntoMultiOutput(unfused_root);
    910   TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
     911   // The remaining instructions are fused in the normal way.
    912   for (int64 i = 1; i < unfused_instructions.size(); i++) {
    913     auto instruction = unfused_instructions[i];
    914     FuseInstruction(instruction);
    915     TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
    916   }
    917 }
    918 
    919 HloInstruction* HloInstruction::FuseInstructionInternal(
    920     HloInstruction* instruction_to_fuse, bool add_output) {
    921   CHECK_EQ(opcode_, HloOpcode::kFusion);
    922 
    923   // When add_output is false, this fusion instruction must be a user of
    924   // instruction_to_fuse.
    925   if (!add_output) {
    926     CHECK(IsUserOf(instruction_to_fuse));
    927   }
    928   HloInstruction* fused_instruction =
    929       CloneAndFuseInternal(instruction_to_fuse, add_output);
    930   return fused_instruction;
    931 }
    932 
    933 HloInstruction* HloInstruction::CloneAndFuseInternal(
    934     HloInstruction* instruction_to_fuse, bool add_output) {
    935   CHECK_EQ(opcode_, HloOpcode::kFusion);
    936   CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString();
    937   VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
    938   HloInstruction* clone = nullptr;
    939   if (called_computations_.empty()) {
    940     // New fusion instruction. It should not be a multioutput instruction.
    941     CHECK(!add_output);
    942     auto builder = HloComputation::Builder("fused_computation", this);
    943     builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
    944     called_computations_.push_back(
    945         CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
    946     clone = fused_expression_root();
    947   } else {
    948     clone = fused_instructions_computation()->AddInstruction(
    949         instruction_to_fuse->Clone(/*suffix=*/""));
    950     // When add_output is false, instruction_to_fuse is necessarily an operand
    951     // of the fusion instruction. After fusion this will no longer be the case.
    952     // Remove the operand from the operand list and remove its corresponding
    953     // fused parameter instruction. Renumber parameters as necessary to make
    954     // parameter numbers consistent with their index in the
    955     // fused_parameter_ vector.
    956     bool in_operand_list = std::find(operands_.begin(), operands_.end(),
    957                                      instruction_to_fuse) != operands_.end();
    958     CHECK(add_output || in_operand_list);
    959     const std::vector<HloInstruction*>& fused_parameters =
    960         fused_instructions_computation()->parameter_instructions();
    961     for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
    962       if (instruction_to_fuse == operands_[operand_num]) {
     963         // Replace the fused parameter instruction's uses with the clone.
    964         HloInstruction* fused_parameter = fused_parameters[operand_num];
    965         TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone));
    966 
    967         // Remove the corresponding fused parameter and operand from their
    968         // respective vectors.
    969         TF_CHECK_OK(
    970             fused_instructions_computation()->RemoveParameter(operand_num));
    971         operands_.erase(operands_.begin() + operand_num);
    972         break;
    973       }
    974     }
    975     // We've cloned instruction_to_fuse into this fusion instruction, so this
    976     // fusion instruction is no longer a use of instruction_to_fuse.
    977     if (in_operand_list) {
    978       instruction_to_fuse->RemoveUser(this);
    979       // When the instruction_to_fuse does not have other users, we don't need
    980       // to generate a multioutput fusion instruction.
    981       if (instruction_to_fuse->user_count() == 0) {
    982         add_output = false;
    983       }
    984     }
    985   }
    986 
    987   // Reread the parameters in the computation.
    988   const std::vector<HloInstruction*>& fused_parameters =
    989       fused_instructions_computation()->parameter_instructions();
    990 
    991   // Add each operand of the clone as an operand of the fusion instruction. A
    992   // complication is that some clone operands may already be operands of the
    993   // fusion instruction.
    994   for (int64 operand_num = 0; operand_num < clone->operand_count();
    995        ++operand_num) {
    996     HloInstruction* operand = clone->mutable_operand(operand_num);
    997 
    998     // See if this operand is already an operand of the fusion node.
    999     CHECK_EQ(operands_.size(), fused_parameters.size());
   1000     HloInstruction* fused_param = nullptr;
   1001     for (int64 i = 0; i < operands_.size(); ++i) {
   1002       if (operands_[i] == operand) {
   1003         fused_param = fused_parameters[i];
   1004         break;
   1005       }
   1006     }
   1007 
   1008     if (fused_param == nullptr) {
   1009       // Clone's operand was not already an operand of the fusion
   1010       // instruction. Add it as an operand and add a corresponding fused
   1011       // parameter instruction.
   1012       fused_param = AddFusionOperand(operand);
   1013     }
   1014     TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
   1015   }
   1016 
   1017   if (add_output) {
   1018     CHECK_GT(instruction_to_fuse->user_count(), 0);
   1019     // If this is already a multioutput fusion instruction, expand the root
   1020     // tuple by 1.
   1021     HloInstruction* fused_root = fused_expression_root();
   1022     HloInstruction::InstructionVector tuple_elements;
   1023     bool newly_created_tuple_instr = false;
   1024     if (fused_root->opcode() == HloOpcode::kTuple) {
   1025       tuple_elements = fused_root->operands();
   1026     } else {
   1027       tuple_elements.push_back(fused_root);
   1028       newly_created_tuple_instr = true;
   1029     }
   1030     if (clone->opcode() == HloOpcode::kTuple) {
   1031       for (auto inst : clone->operands()) {
   1032         tuple_elements.push_back(inst);
   1033       }
   1034     } else {
   1035       tuple_elements.push_back(clone);
   1036     }
   1037     HloInstruction* new_root = fused_instructions_computation()->AddInstruction(
   1038         HloInstruction::CreateTuple(tuple_elements));
   1039     fused_instructions_computation()->set_root_instruction(new_root);
   1040     shape_ = new_root->shape();
   1041     if (fused_root->opcode() == HloOpcode::kTuple) {
   1042       TF_CHECK_OK(
   1043           fused_instructions_computation()->RemoveInstruction(fused_root));
   1044     }
   1045 
   1046     // If this is a newly created multioutput instruction, we need to update
   1047     // the use of the original fusion instruction.
   1048     if (newly_created_tuple_instr) {
   1049       HloInstruction* new_instr = parent_->AddInstruction(
   1050           HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0));
   1051       TF_CHECK_OK(ReplaceAllUsesWith(new_instr));
   1052     }
   1053     int64 index = tuple_elements.size();
   1054     if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
   1055       index -= instruction_to_fuse->operand_count();
   1056       std::vector<HloInstruction*> to_be_removed;
   1057       for (auto old_gte : instruction_to_fuse->users()) {
   1058         CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
   1059         int64 old_tuple_index = old_gte->tuple_index();
   1060         HloInstruction* new_gte =
   1061             parent_->AddInstruction(HloInstruction::CreateGetTupleElement(
   1062                 old_gte->shape(), this, index + old_tuple_index));
   1063         TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
   1064         to_be_removed.push_back(old_gte);
   1065       }
   1066       for (auto old_gte : to_be_removed) {
   1067         TF_CHECK_OK(parent_->RemoveInstruction(old_gte));
   1068       }
   1069       TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone));
   1070     } else {
   1071       HloInstruction* new_gte =
   1072           parent_->AddInstruction(HloInstruction::CreateGetTupleElement(
   1073               clone->shape(), this, index - 1));
   1074       TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte));
   1075     }
   1076   }
   1077 
   1078   VLOG(2) << "New clone:\n" << clone->ToString();
   1079   return clone;
   1080 }
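         // Illustrative scenario (not part of the original file): when add_output is
         // true and instruction_to_fuse still has external users, the code above
         // turns this into a multi-output fusion. For instance, fusing a kExp with
         // one external user into a fusion whose root was a single kAdd leaves the
         // fused computation with root tuple(add, exp); prior uses of the fusion are
         // rewired through get-tuple-element(fusion, 0), and the external user of
         // the kExp is rewired to get-tuple-element(fusion, 1).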
   1081 
   1082 RandomDistribution HloInstruction::random_distribution() const {
   1083   CHECK_EQ(opcode_, HloOpcode::kRng);
   1084   return distribution_;
   1085 }
   1086 
   1087 bool HloInstruction::HasSideEffect() const {
   1088   switch (opcode_) {
   1089     case HloOpcode::kSend:
   1090     case HloOpcode::kSendDone:
   1091     case HloOpcode::kRecv:
   1092     case HloOpcode::kRecvDone:
   1093     case HloOpcode::kRng:
   1094     case HloOpcode::kInfeed:
   1095     case HloOpcode::kOutfeed:
   1096     case HloOpcode::kTrace:
   1097     case HloOpcode::kHostCompute:
   1098       return true;
   1099     default: {
   1100       // Check if any of the called computations has a side effect.
   1101       for (const auto& computation : called_computations()) {
   1102         if (computation->HasSideEffect()) {
   1103           return true;
   1104         }
   1105       }
   1106       return false;
   1107     }
   1108   }
   1109 }
   1110 
   1111 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
   1112     const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
   1113     HloComputation* computation) {
   1114   std::unique_ptr<HloInstruction> instruction =
   1115       WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
   1116   for (auto operand : operands) {
   1117     instruction->AppendOperand(operand);
   1118   }
   1119   instruction->called_computations_.push_back(computation);
   1120   return instruction;
   1121 }
   1122 
   1123 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
   1124     const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
   1125     tensorflow::StringPiece custom_call_target) {
   1126   std::unique_ptr<HloInstruction> instruction =
   1127       WrapUnique(new HloInstruction(HloOpcode::kCustomCall, shape));
   1128   for (auto operand : operands) {
   1129     instruction->AppendOperand(operand);
   1130   }
   1131   instruction->custom_call_target_ = custom_call_target.ToString();
   1132   return instruction;
   1133 }
   1134 
   1135 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateHostCompute(
   1136     const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
   1137     tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) {
   1138   std::unique_ptr<HloInstruction> instruction =
   1139       WrapUnique(new HloInstruction(HloOpcode::kHostCompute, shape));
   1140   for (auto operand : operands) {
   1141     instruction->AppendOperand(operand);
   1142   }
   1143   instruction->channel_name_ = channel_name.ToString();
   1144   instruction->cost_estimate_ns_ = cost_estimate_ns;
   1145   return instruction;
   1146 }
   1147 
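         // Factory for kTuple: the result shape is the tuple of the element shapes.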
   1148 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
   1149     tensorflow::gtl::ArraySlice<HloInstruction*> elements) {
   1150   std::vector<Shape> element_shapes;
   1151   for (auto element : elements) {
   1152     element_shapes.push_back(element->shape());
   1153   }
   1154   Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
   1155   return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
   1156 }
   1157 
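         // Factory for kGather: gathers slices of `operand` at positions given by
         // `gather_indices`, recording the dimension numbers and window bounds.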
   1158 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
   1159     const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
   1160     const GatherDimensionNumbers& gather_dim_numbers,
   1161     tensorflow::gtl::ArraySlice<int64> window_bounds) {
   1162   std::unique_ptr<HloInstruction> instruction =
   1163       WrapUnique(new HloInstruction(HloOpcode::kGather, shape));
   1164   instruction->AppendOperand(operand);
   1165   instruction->AppendOperand(gather_indices);
   1166   instruction->gather_dimension_numbers_ =
   1167       MakeUnique<GatherDimensionNumbers>(gather_dim_numbers);
   1168   c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_));
   1169   return instruction;
   1170 }
   1171 
   1172 /* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers(
   1173     tensorflow::gtl::ArraySlice<int64> output_window_dims,
   1174     tensorflow::gtl::ArraySlice<int64> elided_window_dims,
   1175     tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims) {
   1176   GatherDimensionNumbers gather_dim_numbers;
   1177   for (int64 output_window_dim : output_window_dims) {
   1178     gather_dim_numbers.add_output_window_dims(output_window_dim);
   1179   }
   1180   for (int64 elided_window_dim : elided_window_dims) {
   1181     gather_dim_numbers.add_elided_window_dims(elided_window_dim);
   1182   }
   1183   for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) {
   1184     gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim);
   1185   }
   1186 
   1187   return gather_dim_numbers;
   1188 }
   1189 
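         // Produces a fresh instruction with the same opcode and configuration as
         // this one, but with the given `shape` and `new_operands`. The clone is
         // returned as an owned pointer and is not added to any HloComputation.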
   1190 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
   1191     const Shape& shape,
   1192     tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
   1193     HloModule* module) const {
   1194   VLOG(3) << "CloneWithNewOperands:\n  " << ToString();
   1195   VLOG(3) << "  new operands:";
   1196   for (const HloInstruction* new_operand : new_operands) {
   1197     VLOG(3) << "    %" << new_operand->name();
   1198   }
   1199 
   1200   std::unique_ptr<HloInstruction> clone;
   1201 
   1202   // Explicitly call the factory for the instruction type. This is more robust
   1203   // in the face of code changes than copying fields explicitly. This also
   1204   // properly sets the user fields of the operands.
   1205   switch (opcode_) {
   1206     // Unary ops.
   1207     case HloOpcode::kAbs:
   1208     case HloOpcode::kRoundNearestAfz:
   1209     case HloOpcode::kBitcast:
   1210     case HloOpcode::kCeil:
   1211     case HloOpcode::kCopy:
   1212     case HloOpcode::kCos:
   1213     case HloOpcode::kExp:
   1214     case HloOpcode::kImag:
   1215     case HloOpcode::kIsFinite:
   1216     case HloOpcode::kFloor:
   1217     case HloOpcode::kLog:
   1218     case HloOpcode::kNot:
   1219     case HloOpcode::kNegate:
   1220     case HloOpcode::kReal:
   1221     case HloOpcode::kSign:
   1222     case HloOpcode::kSin:
   1223     case HloOpcode::kSort:
   1224     case HloOpcode::kTanh:
   1225       CHECK_EQ(new_operands.size(), 1);
   1226       clone = CreateUnary(shape, opcode_, new_operands[0]);
   1227       break;
   1228     // Binary ops.
   1229     case HloOpcode::kAdd:
   1230     case HloOpcode::kAtan2:
   1231     case HloOpcode::kComplex:
   1232     case HloOpcode::kDivide:
   1233     case HloOpcode::kMultiply:
   1234     case HloOpcode::kSubtract:
   1235     case HloOpcode::kEq:
   1236     case HloOpcode::kGe:
   1237     case HloOpcode::kGt:
   1238     case HloOpcode::kLe:
   1239     case HloOpcode::kLt:
   1240     case HloOpcode::kNe:
   1241     case HloOpcode::kMaximum:
   1242     case HloOpcode::kMinimum:
   1243     case HloOpcode::kPower:
   1244     case HloOpcode::kRemainder:
   1245     case HloOpcode::kAnd:
   1246     case HloOpcode::kOr:
   1247     case HloOpcode::kShiftLeft:
   1248     case HloOpcode::kShiftRightArithmetic:
   1249     case HloOpcode::kShiftRightLogical:
   1250       CHECK_EQ(new_operands.size(), 2);
   1251       clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
   1252       break;
   1253     // Ternary ops.
   1254     case HloOpcode::kClamp:
   1255     case HloOpcode::kSelect:
   1256       CHECK_EQ(new_operands.size(), 3);
   1257       clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
   1258                             new_operands[2]);
   1259       break;
   1260     // Other supported ops.
   1261     case HloOpcode::kBroadcast:
   1262       CHECK_EQ(new_operands.size(), 1);
   1263       clone = CreateBroadcast(shape, new_operands[0], dimensions_);
   1264       break;
   1265     case HloOpcode::kCall:
   1266       clone = CreateCall(shape, new_operands, to_apply());
   1267       break;
   1268     case HloOpcode::kCustomCall:
   1269       clone = CreateCustomCall(shape, new_operands, custom_call_target_);
   1270       break;
   1271     case HloOpcode::kHostCompute:
   1272       clone = CreateHostCompute(shape, new_operands, channel_name_,
   1273                                 cost_estimate_ns_);
   1274       break;
   1275     case HloOpcode::kConcatenate:
   1276       clone = CreateConcatenate(shape, new_operands, dimensions(0));
   1277       break;
   1278     case HloOpcode::kConvert:
   1279       CHECK_EQ(new_operands.size(), 1);
   1280       clone = CreateConvert(shape, new_operands[0]);
   1281       break;
   1282     case HloOpcode::kBitcastConvert:
   1283       CHECK_EQ(new_operands.size(), 1);
   1284       clone = CreateBitcastConvert(shape, new_operands[0]);
   1285       break;
   1286     case HloOpcode::kReducePrecision:
   1287       CHECK_EQ(new_operands.size(), 1);
   1288       clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_,
   1289                                     mantissa_bits_);
   1290       break;
   1291     case HloOpcode::kConvolution:
   1292       CHECK_EQ(new_operands.size(), 2);
   1293       clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_,
   1294                              *convolution_dimension_numbers_);
   1295       break;
   1296     case HloOpcode::kDot:
   1297       CHECK_EQ(new_operands.size(), 2);
   1298       clone = CreateDot(shape, new_operands[0], new_operands[1],
   1299                         *dot_dimension_numbers_);
   1300       break;
    1301     case HloOpcode::kFft:
    1302       CHECK_EQ(new_operands.size(), 1);
               // Break (do not return) so the epilogue below also sets metadata,
               // sharding and parent on the kFft clone, as for every other opcode.
    1303       clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_);
               break;
   1304     case HloOpcode::kCrossReplicaSum:
   1305       clone = CreateCrossReplicaSum(shape, new_operands);
   1306       break;
   1307     case HloOpcode::kGetTupleElement:
   1308       CHECK_EQ(new_operands.size(), 1);
   1309       clone = CreateGetTupleElement(shape, new_operands[0], tuple_index());
   1310       break;
   1311     case HloOpcode::kMap:
   1312       clone = CreateMap(shape, new_operands, to_apply());
   1313       break;
   1314     case HloOpcode::kPad:
   1315       CHECK_EQ(new_operands.size(), 2);
   1316       clone =
   1317           CreatePad(shape, new_operands[0], new_operands[1], *padding_config_);
   1318       break;
   1319     case HloOpcode::kReduce:
   1320       CHECK_EQ(new_operands.size(), 2);
   1321       clone = CreateReduce(shape, new_operands[0], new_operands[1], dimensions_,
   1322                            to_apply());
   1323       break;
   1324     case HloOpcode::kReduceWindow:
   1325       CHECK_EQ(new_operands.size(), 2);
   1326       clone = CreateReduceWindow(shape, new_operands[0], new_operands[1],
   1327                                  *window_, to_apply());
   1328       break;
   1329     case HloOpcode::kSelectAndScatter:
   1330       CHECK_EQ(new_operands.size(), 3);
   1331       clone =
   1332           CreateSelectAndScatter(shape, new_operands[0], select(), *window_,
   1333                                  new_operands[1], new_operands[2], scatter());
   1334       break;
   1335     case HloOpcode::kReverse:
   1336       CHECK_EQ(new_operands.size(), 1);
   1337       clone = CreateReverse(shape, new_operands[0], dimensions_);
   1338       break;
   1339     case HloOpcode::kRng:
   1340       clone = CreateRng(shape, distribution_, new_operands);
   1341       break;
   1342     case HloOpcode::kReshape:
   1343       CHECK_EQ(new_operands.size(), 1);
   1344       clone = CreateReshape(shape, new_operands[0]);
   1345       break;
   1346     case HloOpcode::kSlice:
   1347       CHECK_EQ(new_operands.size(), 1);
   1348       clone = CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_,
   1349                           slice_strides_);
   1350       break;
    1351     case HloOpcode::kDynamicSlice:
               CHECK_EQ(new_operands.size(), 2);
   1352       clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1],
   1353                                  dynamic_slice_sizes_);
   1354       break;
   1355     case HloOpcode::kDynamicUpdateSlice:
   1356       CHECK_EQ(new_operands.size(), 3);
   1357       clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
   1358                                        new_operands[2]);
   1359       break;
   1360     case HloOpcode::kTranspose:
   1361       CHECK_EQ(new_operands.size(), 1);
   1362       clone = CreateTranspose(shape, new_operands[0], dimensions_);
   1363       break;
   1364     case HloOpcode::kTuple:
   1365       clone = CreateTuple(new_operands);
   1366       *clone->mutable_shape() = shape;
   1367       break;
   1368     case HloOpcode::kWhile:
   1369       CHECK_EQ(new_operands.size(), 1);
   1370       clone =
   1371           CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
   1372       break;
   1373     case HloOpcode::kConstant:
   1374       clone = CreateConstant(literal_->CloneToUnique());
   1375       break;
   1376     case HloOpcode::kFusion:
   1377       clone = CloneFusionWithNewOperands(shape, new_operands, module);
   1378       break;
   1379     case HloOpcode::kParameter:
   1380       clone = CreateParameter(parameter_number_, shape, name_);
   1381       break;
   1382     case HloOpcode::kBatchNormTraining:
   1383       CHECK_EQ(new_operands.size(), 3);
   1384       clone =
   1385           CreateBatchNormTraining(shape, new_operands[0], new_operands[1],
   1386                                   new_operands[2], epsilon(), feature_index());
   1387       break;
   1388     case HloOpcode::kBatchNormInference:
   1389       CHECK_EQ(new_operands.size(), 5);
   1390       clone = CreateBatchNormInference(
   1391           shape, new_operands[0], new_operands[1], new_operands[2],
   1392           new_operands[3], new_operands[4], epsilon(), feature_index());
   1393       break;
   1394     case HloOpcode::kInfeed:
   1395       CHECK_EQ(new_operands.size(), 0);
   1396       clone = CreateInfeed(shape, infeed_config());
   1397       break;
   1398     case HloOpcode::kOutfeed:
   1399       CHECK_EQ(new_operands.size(), 1);
   1400       clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config());
   1401       break;
   1402     case HloOpcode::kBatchNormGrad:
   1403       CHECK_EQ(new_operands.size(), 5);
   1404       clone = CreateBatchNormGrad(shape, new_operands[0], new_operands[1],
   1405                                   new_operands[2], new_operands[3],
   1406                                   new_operands[4], epsilon(), feature_index());
   1407       break;
   1408     case HloOpcode::kConditional:
   1409       CHECK_EQ(new_operands.size(), 3);
   1410       clone = CreateConditional(shape, new_operands[0], new_operands[1],
   1411                                 true_computation(), new_operands[2],
   1412                                 false_computation());
   1413       break;
   1414     case HloOpcode::kSend:
   1415       CHECK_EQ(new_operands.size(), 1);
   1416       clone = CreateSend(new_operands[0], channel_id());
   1417       break;
   1418     case HloOpcode::kSendDone:
   1419       CHECK_EQ(new_operands.size(), 1);
   1420       clone = CreateSendDone(new_operands[0]);
   1421       break;
   1422     case HloOpcode::kRecv:
   1423       CHECK_EQ(new_operands.size(), 0);
   1424       // The shape is a tuple, but CreateRecv() wants the raw data shape.
   1425       clone =
   1426           CreateRecv(ShapeUtil::GetTupleElementShape(shape, 0), channel_id());
   1427       break;
   1428     case HloOpcode::kRecvDone:
   1429       CHECK_EQ(new_operands.size(), 1);
   1430       clone = CreateRecvDone(new_operands[0]);
   1431       break;
   1432     case HloOpcode::kGather:
   1433       CHECK_EQ(new_operands.size(), 2);
   1434       clone = CreateGather(shape, new_operands[0], new_operands[1],
   1435                            *gather_dimension_numbers_, gather_window_bounds_);
   1436       break;
   1437     case HloOpcode::kTrace:
   1438       LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
   1439   }
   1440   clone->set_metadata(metadata_);
   1441   if (has_sharding()) {
   1442     clone->set_sharding(sharding());
   1443   }
   1444   clone->set_parent(parent_);
   1445   return clone;
   1446 }
   1447 
   1448 HloInstruction::~HloInstruction() {}
   1449 
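         // Clones this instruction with the same operands, appending `suffix` to the
         // name: cloning %foo with suffix "clone" yields %foo.clone, and cloning
         // that clone again yields %foo.clone2.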
   1450 std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix,
   1451                                                       HloModule* module) const {
   1452   std::unique_ptr<HloInstruction> clone =
   1453       CloneWithNewOperands(shape_, operands_, module);
   1454   if (suffix.empty()) {
   1455     clone->name_ = name();
   1456   } else {
    1457     // If an instruction is cloned multiple times, avoid names like
    1458     // foo.suffix.suffix.suffix. Instead of repeating the suffix, append a
    1459     // numeric suffix: the clone of foo.suffix is named foo.suffix2, the clone
    1460     // of foo.suffix2 is named foo.suffix3, and so on.
   1461     const string dot_suffix = "." + suffix;
   1462     size_t index = name().rfind(dot_suffix);
   1463     if (index == string::npos) {
   1464       // Existing name does not include ".suffix".
   1465       clone->name_ = name() + dot_suffix;
   1466     } else {
   1467       // Existing name includes ".suffix". Determine if substring after
   1468       // ".suffix" is numeric and should be replaced with an incremented number.
   1469       string after_suffix = name().substr(index + dot_suffix.size());
   1470       if (after_suffix.empty()) {
   1471         // Existing name ends in ".suffix". New name should end in ".suffix2".
   1472         clone->name_ = name() + "2";
   1473       } else {
   1474         // If names ends with .suffix[0-9]+ then replace with a suffix with the
    1475         // If the name ends with ".suffix[0-9]+", replace the trailing number
    1476         // with its incremented value.
   1477         if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) {
   1478           clone->name_ =
   1479               StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
   1480         } else {
   1481           // Substring after ".suffix" is non-numeric.
   1482           clone->name_ = name() + dot_suffix;
   1483         }
   1484       }
   1485     }
   1486   }
   1487   return clone;
   1488 }
   1489 
   1490 std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
   1491     const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
   1492     HloModule* module) const {
   1493   CHECK_EQ(opcode_, HloOpcode::kFusion);
   1494   CHECK(parent() != nullptr);
   1495 
   1496   auto new_instruction =
   1497       WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
   1498   // Add the operands to our new fusion instruction.
   1499   for (HloInstruction* new_operand : operands) {
   1500     new_instruction->AppendOperand(new_operand);
   1501   }
   1502   // Clone all the fused instructions for the new fusion instruction.
   1503   HloInstructionMap<HloInstruction*> old_to_new;
   1504   std::list<std::unique_ptr<HloInstruction>> new_fused_instructions;
   1505   // Create the list of fused parameters by mapping through the cloned,
   1506   // fused instructions.
   1507   for (HloInstruction* old_fused_parameter :
   1508        fused_instructions_computation()->parameter_instructions()) {
   1509     new_fused_instructions.push_back(
   1510         old_fused_parameter->Clone("clone", module));
   1511     HloInstruction* new_fusion_parameter = new_fused_instructions.back().get();
   1512     InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter);
   1513   }
   1514   for (auto old_fused_instruction :
   1515        fused_instructions_computation()->MakeInstructionPostOrder()) {
   1516     if (old_fused_instruction->opcode() == HloOpcode::kParameter) {
   1517       FindOrDie(old_to_new, old_fused_instruction);
   1518       continue;
   1519     }
   1520     std::vector<HloInstruction*> new_operands;
   1521     for (int64 operand_idx = 0;
   1522          operand_idx < old_fused_instruction->operand_count(); ++operand_idx) {
   1523       HloInstruction* old_operand =
   1524           old_fused_instruction->mutable_operand(operand_idx);
   1525       new_operands.push_back(FindOrDie(old_to_new, old_operand));
   1526     }
   1527     new_fused_instructions.push_back(
   1528         old_fused_instruction->CloneWithNewOperands(
   1529             old_fused_instruction->shape(), new_operands, module));
   1530     HloInstruction* new_fused_instruction = new_fused_instructions.back().get();
   1531     new_fused_instruction->set_parent(parent_);
   1532     InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction);
   1533   }
   1534   new_instruction->fusion_kind_ = fusion_kind_;
   1535   auto computation_builder = HloComputation::Builder(
   1536       fused_instructions_computation()->name() + ".clone",
   1537       new_instruction.get());
   1538   // We iterated the fusion instructions in reverse post order which means
   1539   // that we must reverse our new list of fusion instructions.
   1540   for (auto new_fused_instruction_iter = new_fused_instructions.rbegin();
   1541        new_fused_instruction_iter != new_fused_instructions.rend();
   1542        ++new_fused_instruction_iter) {
   1543     computation_builder.AddInstruction(std::move(*new_fused_instruction_iter));
   1544   }
   1545   if (module == nullptr) {
   1546     module = GetModule();
   1547   }
    1548   auto fused_root = fused_expression_root();
    1549   new_instruction->called_computations_.push_back(
    1550       CHECK_NOTNULL(module)->AddEmbeddedComputation(
    1551           computation_builder.Build(FindOrDie(old_to_new, fused_root))));
   1552   return new_instruction;
   1553 }
   1554 
   1555 std::pair<const HloInstruction*, ShapeIndex>
   1556 HloInstruction::LatestNonGteAncestorAndIndex() const {
   1557   const HloInstruction* hlo = this;
   1558   ShapeIndex index;
   1559   while (hlo->opcode() == HloOpcode::kGetTupleElement) {
   1560     index.push_back(hlo->tuple_index());
   1561     hlo = hlo->operand(0);
   1562   }
   1563 
    1564   // We built up the index in the reverse of the order we want.
   1565   std::reverse(index.begin(), index.end());
   1566 
   1567   return {hlo, index};
   1568 }
   1569 
   1570 const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
   1571   const HloInstruction* hlo = this;
   1572   while (hlo->opcode() == HloOpcode::kGetTupleElement) {
   1573     hlo = hlo->operand(0);
   1574   }
   1575   return hlo;
   1576 }
   1577 
   1578 const Literal& HloInstruction::literal() const {
   1579   CHECK_EQ(HloOpcode::kConstant, opcode_);
   1580   return *literal_;
   1581 }
   1582 
   1583 bool HloInstruction::CanHaveDimensionsField() const {
   1584   return (opcode() == HloOpcode::kReverse ||
   1585           opcode() == HloOpcode::kConcatenate ||
   1586           opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kBroadcast ||
   1587           opcode() == HloOpcode::kTranspose);
   1588 }
   1589 
   1590 const std::vector<int64>& HloInstruction::dimensions() const {
   1591   CHECK(CanHaveDimensionsField());
   1592   return dimensions_;
   1593 }
   1594 
   1595 int64 HloInstruction::dimensions(int64 index) const {
   1596   return dimensions()[index];
   1597 }
   1598 
   1599 int64 HloInstruction::concatenate_dimension() const {
   1600   CHECK(opcode() == HloOpcode::kConcatenate);
   1601   CHECK_EQ(1, dimensions_.size());
   1602   return dimensions(0);
   1603 }
   1604 
   1605 int64 HloInstruction::tuple_index() const {
   1606   CHECK_EQ(HloOpcode::kGetTupleElement, opcode_);
   1607   return tuple_index_;
   1608 }
   1609 
   1610 const HloInstruction* HloInstruction::operand(int64 i) const {
   1611   return operands_[i];
   1612 }
   1613 
   1614 HloInstruction* HloInstruction::mutable_operand(int64 i) {
   1615   CHECK(operands_[i] != nullptr);
   1616   return operands_[i];
   1617 }
   1618 
   1619 int64 HloInstruction::operand_index(const HloInstruction* target) const {
   1620   for (int64 i = 0; i < operand_count(); ++i) {
   1621     if (target == operand(i)) {
   1622       return i;
   1623     }
   1624   }
   1625   LOG(FATAL) << "target was not an operand: " << target->ToString();
   1626 }
   1627 
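         // Adds a control edge from this instruction to `instruction`, constraining
         // it to execute after this one. Both must belong to the same computation;
         // adding an edge that already exists is a no-op.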
   1628 Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
   1629   TF_RET_CHECK(instruction->parent() == parent());
   1630   if (std::find(control_successors_.begin(), control_successors_.end(),
   1631                 instruction) == control_successors_.end()) {
   1632     control_successors_.push_back(instruction);
   1633     TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(),
   1634                            instruction->control_predecessors_.end(),
   1635                            this) == instruction->control_predecessors_.end());
   1636     instruction->control_predecessors_.push_back(this);
   1637   }
   1638   return Status::OK();
   1639 }
   1640 
   1641 Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
   1642   auto succ_it = std::find(control_successors_.begin(),
   1643                            control_successors_.end(), instruction);
   1644   TF_RET_CHECK(succ_it != control_successors_.end());
   1645   control_successors_.erase(succ_it);
   1646   auto pred_it = std::find(instruction->control_predecessors_.begin(),
   1647                            instruction->control_predecessors_.end(), this);
   1648   TF_RET_CHECK(pred_it != instruction->control_predecessors_.end());
   1649   instruction->control_predecessors_.erase(pred_it);
   1650 
   1651   return Status::OK();
   1652 }
   1653 
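         // Appends `operand` to the operand list and registers this instruction as
         // one of the operand's users.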
   1654 void HloInstruction::AppendOperand(HloInstruction* operand) {
   1655   operands_.push_back(operand);
   1656   operand->AddUser(this);
   1657 }
   1658 
   1659 void HloInstruction::AddUser(HloInstruction* user) {
   1660   if (!ContainsKey(user_set_, user)) {
   1661     user_set_.insert(user);
   1662     users_.push_back(user);
   1663   }
   1664 }
   1665 
   1666 bool HloInstruction::IsConstant() const {
   1667   return opcode_ == HloOpcode::kConstant;
   1668 }
   1669 
   1670 bool HloInstruction::HasConstantOperand() const {
   1671   for (const HloInstruction* operand : operands_) {
   1672     if (operand->IsConstant()) {
   1673       return true;
   1674     }
   1675   }
   1676   return false;
   1677 }
   1678 
   1679 bool HloInstruction::IdenticalSlowPath(
   1680     const HloInstruction& other,
   1681     const std::function<bool(const HloComputation*, const HloComputation*)>&
   1682         eq_computations,
   1683     const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const {
   1684   // Perform opcode specific checks.
   1685   switch (opcode()) {
    1686     // The result of these instructions depends only on their opcode and
    1687     // operands.
   1688     case HloOpcode::kAbs:
   1689     case HloOpcode::kAtan2:
   1690     case HloOpcode::kRoundNearestAfz:
   1691     case HloOpcode::kAdd:
   1692     case HloOpcode::kCeil:
   1693     case HloOpcode::kClamp:
   1694     case HloOpcode::kComplex:
   1695     case HloOpcode::kCopy:
   1696     case HloOpcode::kCos:
   1697     case HloOpcode::kCrossReplicaSum:
   1698     case HloOpcode::kDivide:
   1699     case HloOpcode::kEq:
   1700     case HloOpcode::kExp:
   1701     case HloOpcode::kFloor:
   1702     case HloOpcode::kGe:
   1703     case HloOpcode::kGt:
   1704     case HloOpcode::kImag:
   1705     case HloOpcode::kIsFinite:
   1706     case HloOpcode::kLe:
   1707     case HloOpcode::kLog:
   1708     case HloOpcode::kAnd:
   1709     case HloOpcode::kNot:
   1710     case HloOpcode::kOr:
   1711     case HloOpcode::kLt:
   1712     case HloOpcode::kMaximum:
   1713     case HloOpcode::kMinimum:
   1714     case HloOpcode::kMultiply:
   1715     case HloOpcode::kNe:
   1716     case HloOpcode::kNegate:
   1717     case HloOpcode::kPower:
   1718     case HloOpcode::kReal:
   1719     case HloOpcode::kRemainder:
   1720     case HloOpcode::kSelect:
   1721     case HloOpcode::kShiftLeft:
   1722     case HloOpcode::kShiftRightArithmetic:
   1723     case HloOpcode::kShiftRightLogical:
   1724     case HloOpcode::kSign:
   1725     case HloOpcode::kSin:
   1726     case HloOpcode::kSubtract:
   1727     case HloOpcode::kTanh:
   1728     case HloOpcode::kTuple:
   1729       return true;
   1730 
   1731     case HloOpcode::kFusion:
   1732       return fusion_kind() == other.fusion_kind() &&
   1733              eq_computations(fused_instructions_computation(),
   1734                              other.fused_instructions_computation());
   1735 
   1736     // These opcodes have complex or special behavior so just return false.
   1737     case HloOpcode::kRng:
   1738     case HloOpcode::kTrace:
   1739     case HloOpcode::kWhile:
   1740       return false;
   1741 
   1742     case HloOpcode::kParameter:
   1743       return parameter_number() == other.parameter_number() &&
   1744              // Check the shape too because `this` and `other` may be in
   1745              // different HloComputations.
   1746              eq_shapes(shape(), other.shape());
   1747 
   1748     case HloOpcode::kBatchNormTraining:
   1749     case HloOpcode::kBatchNormInference:
   1750     case HloOpcode::kBatchNormGrad:
   1751       return feature_index() == other.feature_index() &&
   1752              epsilon() == other.epsilon();
   1753 
   1754     // A constant is defined by the value in the literal.
   1755     case HloOpcode::kConstant:
   1756       return literal() == other.literal();
   1757 
   1758     // A convert result is determined by the primitive type that the operand is
   1759     // converted into.
   1760     case HloOpcode::kConvert:
   1761     case HloOpcode::kBitcastConvert:
   1762       return shape().element_type() == other.shape().element_type();
   1763 
   1764     // A reduce-precision operation is determined by the bit sizes.
   1765     case HloOpcode::kReducePrecision:
   1766       return exponent_bits() == other.exponent_bits() &&
   1767              mantissa_bits() == other.mantissa_bits();
   1768 
   1769     // Convolution has a window and dimensions.
   1770     case HloOpcode::kConvolution:
   1771       return protobuf_util::ProtobufEquals(window(), other.window()) &&
   1772              protobuf_util::ProtobufEquals(
   1773                  convolution_dimension_numbers(),
   1774                  other.convolution_dimension_numbers());
   1775     // Check dot dimension numbers.
   1776     case HloOpcode::kDot:
   1777       return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
   1778                                            other.dot_dimension_numbers());
   1779 
   1780     case HloOpcode::kGather:
   1781       return protobuf_util::ProtobufEquals(gather_dimension_numbers(),
   1782                                            other.gather_dimension_numbers()) &&
   1783              gather_window_bounds() == other.gather_window_bounds();
   1784 
   1785     // FFT has various types & lengths.
   1786     case HloOpcode::kFft:
   1787       return fft_type() == other.fft_type() &&
   1788              fft_length() == other.fft_length();
   1789 
   1790     // Reduction results are determined by the reduction dimension and the
   1791     // reduction computation.
   1792     case HloOpcode::kReduce:
   1793       return dimensions() == other.dimensions() &&
   1794              eq_computations(to_apply(), other.to_apply());
   1795     case HloOpcode::kReduceWindow:
   1796       return eq_computations(to_apply(), other.to_apply()) &&
   1797              protobuf_util::ProtobufEquals(window(), other.window());
   1798 
   1799     // SelectAndScatter is determined by both select and scatter
   1800     // computation as well as the window configuration.
   1801     case HloOpcode::kSelectAndScatter:
   1802       return eq_computations(select(), other.select()) &&
   1803              eq_computations(scatter(), other.scatter()) &&
   1804              protobuf_util::ProtobufEquals(window(), other.window());
   1805 
   1806     case HloOpcode::kReshape:
   1807       return eq_shapes(shape(), other.shape());
   1808 
   1809     // Transpose result is determined by the final shape and the permutation.
   1810     case HloOpcode::kTranspose:
   1811       return eq_shapes(shape(), other.shape()) &&
   1812              dimensions() == other.dimensions();
   1813 
   1814     // Remaining instructions with special values.
   1815     case HloOpcode::kBitcast:
   1816       return eq_shapes(shape(), other.shape());
   1817     case HloOpcode::kBroadcast:
   1818       return eq_shapes(shape(), other.shape()) &&
   1819              dimensions() == other.dimensions();
   1820     case HloOpcode::kConcatenate:
   1821       return dimensions() == other.dimensions();
   1822     case HloOpcode::kGetTupleElement:
   1823       return tuple_index() == other.tuple_index();
   1824     case HloOpcode::kPad:
   1825       return protobuf_util::ProtobufEquals(padding_config(),
   1826                                            other.padding_config());
   1827     case HloOpcode::kSlice:
   1828       return slice_starts_ == other.slice_starts_ &&
   1829              slice_limits_ == other.slice_limits_ &&
   1830              slice_strides_ == other.slice_strides_;
   1831     case HloOpcode::kDynamicSlice:
   1832       return eq_shapes(shape(), other.shape()) &&
   1833              dynamic_slice_sizes_ == other.dynamic_slice_sizes_;
   1834     case HloOpcode::kDynamicUpdateSlice:
   1835       return eq_shapes(shape(), other.shape());
   1836     case HloOpcode::kCall:
   1837     case HloOpcode::kMap:
   1838       return eq_computations(to_apply(), other.to_apply());
   1839     case HloOpcode::kCustomCall:
   1840       return custom_call_target_ == other.custom_call_target_;
   1841     case HloOpcode::kReverse:
   1842       return dimensions() == other.dimensions();
   1843     case HloOpcode::kConditional:
   1844       return eq_computations(true_computation(), other.true_computation()) &&
   1845              eq_computations(false_computation(), other.false_computation());
   1846 
   1847     // These opcodes are not yet supported.
   1848     case HloOpcode::kInfeed:
   1849     case HloOpcode::kOutfeed:
   1850     case HloOpcode::kSort:
   1851     case HloOpcode::kRecv:
   1852     case HloOpcode::kRecvDone:
   1853     case HloOpcode::kSend:
   1854     case HloOpcode::kSendDone:
   1855     case HloOpcode::kHostCompute:
   1856       return false;
   1857   }
   1858 }
   1859 
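         // Returns true if this is a transpose of a rank-2 operand that swaps its
         // two dimensions, i.e. a plain matrix transpose.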
   1860 bool HloInstruction::IsRank2Transpose() const {
   1861   return (opcode_ == HloOpcode::kTranspose) &&
   1862          dimensions_ == std::vector<int64>({1, 0}) &&
   1863          shape_.dimensions_size() == 2 &&
   1864          std::equal(shape_.dimensions().begin(), shape_.dimensions().end(),
   1865                     operands_[0]->shape_.dimensions().rbegin());
   1866 }
   1867 
   1868 void HloInstruction::RemoveUser(HloInstruction* user) {
   1869   auto set_it = user_set_.find(user);
   1870   CHECK(set_it != user_set_.end());
   1871   user_set_.erase(set_it);
    1872   // This is linear in the number of users, but a vector provides a stable
   1873   // iteration order and much faster traversal.
   1874   auto vec_it = std::find(users_.begin(), users_.end(), user);
   1875   CHECK(vec_it != users_.end());
   1876   users_.erase(vec_it);
   1877 }
   1878 
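         // Rewires a single use: `user` stops consuming this instruction and
         // consumes `new_producer` instead. The two shapes must be compatible,
         // ignoring floating-point precision.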
   1879 Status HloInstruction::ReplaceUseWith(HloInstruction* user,
   1880                                       HloInstruction* new_producer) {
   1881   TF_RET_CHECK(
   1882       ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
   1883       << "this shape: " << ShapeUtil::HumanString(shape())
   1884       << ", replacement shape: "
   1885       << ShapeUtil::HumanString(new_producer->shape());
   1886 
   1887   VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
   1888           << " with " << new_producer->name();
   1889 
   1890   RemoveUser(user);
   1891 
           // Verify that `user` actually lists this instruction among its operands;
           // a `>= 0` comparison on std::count would be vacuously true.
    1892   TF_RET_CHECK(
    1893       std::count(user->operands_.begin(), user->operands_.end(), this) >= 1);
   1894   std::replace(user->operands_.begin(), user->operands_.end(), this,
   1895                new_producer);
   1896   new_producer->AddUser(user);
   1897   return Status::OK();
   1898 }
   1899 
   1900 Status HloInstruction::ReplaceOperandWith(int64 operand_num,
   1901                                           HloInstruction* new_operand) {
   1902   TF_RET_CHECK(operand_num >= 0);
   1903   TF_RET_CHECK(operand_num < operand_count());
   1904   HloInstruction* old_operand = mutable_operand(operand_num);
   1905   TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
   1906                                                         new_operand->shape()))
   1907       << old_operand->shape().ShortDebugString() << " is not compatible with "
   1908       << new_operand->shape().ShortDebugString();
   1909   operands_[operand_num] = new_operand;
   1910 
   1911   VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
   1912           << new_operand->name() << ", was " << old_operand->name();
   1913 
   1914   if (std::find(operands_.begin(), operands_.end(), old_operand) ==
   1915       operands_.end()) {
   1916     old_operand->RemoveUser(this);
   1917   }
   1918   new_operand->AddUser(this);
   1919   return Status::OK();
   1920 }
   1921 
   1922 Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
   1923   bool new_producer_is_user = false;
   1924   for (HloInstruction* user : users()) {
   1925     if (user == new_producer) {
   1926       // It's possible that new_producer is a user of this instruction as might
   1927       // be the case when replacing an instruction with a kCopy of itself. In
   1928       // this case, don't do the replacement to avoid creating a cycle in the
   1929       // graph. new_producer remains the only user of this instruction.
   1930       new_producer_is_user = true;
   1931     } else {
   1932       std::replace(user->operands_.begin(), user->operands_.end(), this,
   1933                    new_producer);
   1934       new_producer->AddUser(user);
   1935     }
   1936   }
   1937   users_.clear();
   1938   user_set_.clear();
   1939   if (new_producer_is_user) {
   1940     AddUser(new_producer);
   1941   }
   1942   if (parent_ && parent_->root_instruction() == this) {
   1943     parent_->set_root_instruction(new_producer);
   1944   }
   1945 
   1946   return Status::OK();
   1947 }
   1948 
   1949 void HloInstruction::DetachFromOperands() {
   1950   VLOG(3) << "DetachFromOperands:\n  " << ToString();
   1951   CHECK_EQ(0, user_count());
   1952   // An instruction may be repeated as an operand. To avoid calling RemoveUser
   1953   // twice on the same operand, keep a set of already detached operands.
   1954   std::set<HloInstruction*> detached_operands;
   1955   for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
   1956     HloInstruction* operand = operands_[operand_num];
   1957     if (!ContainsKey(detached_operands, operand)) {
   1958       operand->RemoveUser(this);
   1959       detached_operands.insert(operand);
   1960     }
   1961     operands_[operand_num] = nullptr;
   1962   }
   1963 }
   1964 
   1965 HloComputation* HloInstruction::to_apply() const {
   1966   switch (opcode_) {
   1967     case HloOpcode::kCall:
   1968     case HloOpcode::kMap:
   1969     case HloOpcode::kReduceWindow:
   1970     case HloOpcode::kReduce:
   1971       CHECK_EQ(called_computations_.size(), 1);
   1972       return called_computations_[0];
   1973     default:
   1974       LOG(FATAL) << "Invalid opcode for to_apply(): "
   1975                  << HloOpcodeString(opcode());
   1976   }
   1977 }
   1978 
   1979 void HloInstruction::set_to_apply(HloComputation* computation) {
   1980   // Don't allow changing the computation for fused instructions so we don't
   1981   // have to recompute called_instructions for the entire fusion instruction.
   1982   CHECK(!IsFused());
   1983   switch (opcode_) {
   1984     case HloOpcode::kCall:
   1985     case HloOpcode::kMap:
   1986     case HloOpcode::kReduceWindow:
   1987     case HloOpcode::kReduce:
   1988       CHECK_EQ(called_computations_.size(), 1);
   1989       called_computations_[0] = computation;
   1990       break;
   1991     default:
    1992       LOG(FATAL) << "Invalid opcode for set_to_apply(): "
   1993                  << HloOpcodeString(opcode());
   1994   }
   1995 }
   1996 
   1997 const string& HloInstruction::custom_call_target() const {
   1998   CHECK_EQ(opcode_, HloOpcode::kCustomCall);
   1999   return custom_call_target_;
   2000 }
   2001 
   2002 const string& HloInstruction::outfeed_config() const {
   2003   CHECK_EQ(opcode_, HloOpcode::kOutfeed);
   2004   return outfeed_config_;
   2005 }
   2006 
   2007 HloComputation* HloInstruction::while_condition() const {
   2008   CHECK_EQ(HloOpcode::kWhile, opcode_);
   2009   return called_computations_[kConditionComputationIndex];
   2010 }
   2011 
   2012 HloComputation* HloInstruction::while_body() const {
   2013   CHECK_EQ(HloOpcode::kWhile, opcode_);
   2014   return called_computations_[kBodyComputationIndex];
   2015 }
   2016 
   2017 void HloInstruction::set_while_condition(HloComputation* computation) {
   2018   // Don't allow changing the computation for fused instructions so we don't
   2019   // have to recompute called_instructions for the entire fusion instruction.
   2020   CHECK(!IsFused());
   2021   CHECK_EQ(HloOpcode::kWhile, opcode_);
   2022   called_computations_[kConditionComputationIndex] = computation;
   2023 }
   2024 
   2025 void HloInstruction::set_while_body(HloComputation* computation) {
   2026   // Don't allow changing the computation for fused instructions so we don't
   2027   // have to recompute called_instructions for the entire fusion instruction.
   2028   CHECK(!IsFused());
   2029   CHECK_EQ(HloOpcode::kWhile, opcode_);
   2030   called_computations_[kBodyComputationIndex] = computation;
   2031 }
   2032 
   2033 HloComputation* HloInstruction::select() const {
   2034   CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
   2035   return called_computations_[kSelectComputationIndex];
   2036 }
   2037 
   2038 HloComputation* HloInstruction::scatter() const {
   2039   CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
   2040   return called_computations_[kScatterComputationIndex];
   2041 }
   2042 
   2043 void HloInstruction::set_select(HloComputation* computation) {
   2044   // Don't allow changing the computation for fused instructions so we don't
   2045   // have to recompute called_instructions for the entire fusion instruction.
   2046   CHECK(!IsFused());
   2047   CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
   2048   called_computations_[kSelectComputationIndex] = computation;
   2049 }
   2050 
   2051 void HloInstruction::set_scatter(HloComputation* computation) {
   2052   // Don't allow changing the computation for fused instructions so we don't
   2053   // have to recompute called_instructions for the entire fusion instruction.
   2054   CHECK(!IsFused());
   2055   CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
   2056   called_computations_[kScatterComputationIndex] = computation;
   2057 }
   2058 
   2059 HloComputation* HloInstruction::true_computation() const {
   2060   CHECK_EQ(HloOpcode::kConditional, opcode_);
   2061   return called_computations_[kTrueComputationIndex];
   2062 }
   2063 
   2064 HloComputation* HloInstruction::false_computation() const {
   2065   CHECK_EQ(HloOpcode::kConditional, opcode_);
   2066   return called_computations_[kFalseComputationIndex];
   2067 }
   2068 
   2069 void HloInstruction::set_true_computation(HloComputation* true_computation) {
   2070   // Don't allow changing the computation for fused instructions so we don't
   2071   // have to recompute called_instructions for the entire fusion instruction.
   2072   CHECK(!IsFused());
   2073   CHECK_EQ(HloOpcode::kConditional, opcode_);
   2074   called_computations_[kTrueComputationIndex] = true_computation;
   2075 }
   2076 
   2077 void HloInstruction::set_false_computation(HloComputation* false_computation) {
   2078   // Don't allow changing the computation for fused instructions so we don't
   2079   // have to recompute called_instructions for the entire fusion instruction.
   2080   CHECK(!IsFused());
   2081   CHECK_EQ(HloOpcode::kConditional, opcode_);
   2082   called_computations_[kFalseComputationIndex] = false_computation;
   2083 }
   2084 
   2085 string HloInstruction::SignatureString() const {
   2086   string operands =
   2087       Join(operands_, ", ", [](string* out, HloInstruction* operand) {
   2088         StrAppend(out, ShapeUtil::HumanString(operand->shape()));
   2089       });
   2090   return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
   2091 }
   2092 
   2093 namespace {
   2094 
   2095 string PrintName(const string& name, const HloPrintOptions& options) {
   2096   return StrCat(options.print_percent() ? "%" : "", name);
   2097 }
   2098 
   2099 }  // namespace
   2100 
   2101 string HloInstruction::ToString(const HloPrintOptions& options) const {
   2102   string result =
   2103       StrCat(PrintName(name(), options), " = ",
   2104              ShapeUtil::HumanStringWithLayout(shape()), " ",
   2105              HloOpcodeString(opcode()), "(", OperandsToString(options), ")");
   2106   for (const string& extra : ExtraAttributesToString(options)) {
   2107     StrAppend(&result, ", ", extra);
   2108   }
   2109   if (options.print_metadata() &&
   2110       (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
   2111        !metadata_.source_file().empty())) {
   2112     StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
   2113   }
   2114   return result;
   2115 }
   2116 
   2117 string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
   2118   string operands;
   2119   if (opcode() == HloOpcode::kConstant) {
   2120     // For constants, show the actual value in place of an empty operand list.
   2121     if ((!ShapeUtil::IsTuple(shape()) &&
   2122          ShapeUtil::ElementsIn(shape()) <= 10) ||
   2123         options.print_large_constants()) {
   2124       // Literal::ToString emits multidimensional arrays over multiple
   2125       // lines. Compact this into one line by stripping out white space.
   2126       string tmp = literal().ToString();
   2127       std::replace(tmp.begin(), tmp.end(), '\n', ' ');
   2128       std::vector<string> v = tensorflow::str_util::Split(tmp, ' ');
   2129       bool first = true;
   2130       // Concatenate elements in "v" with spaces separating them, but ignoring
   2131       // empty entries.
   2132       for (const auto& s : v) {
   2133         if (s.empty()) {
   2134           continue;
   2135         }
   2136         StrAppend(&operands, (first ? "" : " "), s);
   2137         first = false;
   2138       }
   2139     } else {
   2140       // Do not show large constants or tuples.
   2141       operands = "{...}";
   2142     }
   2143   } else if (opcode() == HloOpcode::kParameter) {
   2144     StrAppend(&operands, parameter_number_);
   2145   } else {
   2146     tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_);
   2147     const int64 kMaxOperandsToShowIfCompact = 4;
   2148     if (options.compact_operands() &&
   2149         slice.size() > kMaxOperandsToShowIfCompact) {
   2150       slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
   2151     }
   2152     operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
   2153       std::vector<string> str;
   2154       if (options.print_operand_shape()) {
   2155         str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
   2156       }
   2157       if (!options.compact_operands()) {
   2158         str.push_back(PrintName(operand->name(), options));
   2159       }
   2160       StrAppend(out, Join(str, " "));
   2161     });
   2162     const int64 remaining = operands_.size() - slice.size();
   2163     if (slice.size() != operands_.size()) {
   2164       StrAppend(&operands, ", ...(+", remaining, ")");
   2165     }
   2166   }
   2167   return operands;
   2168 }
   2169 
   2170 std::vector<string> HloInstruction::ExtraAttributesToString(
   2171     const HloPrintOptions& options) const {
   2172   std::vector<string> extra;
   2173   if (opcode() == HloOpcode::kFusion) {
   2174     extra.push_back(StrCat("kind=", xla::ToString(fusion_kind())));
   2175   }
   2176   if (CanHaveDimensionsField()) {
   2177     extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
   2178   }
   2179   if (window_ != nullptr && window_->dimensions_size() != 0) {
   2180     extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
   2181   }
   2182   if (padding_config_ != nullptr) {
   2183     extra.push_back(
   2184         StrCat("padding=", xla::PaddingConfigToString(*padding_config_)));
   2185   }
   2186   if (opcode() == HloOpcode::kSlice) {
   2187     std::vector<string> bounds;
   2188     bounds.reserve(slice_starts_.size());
   2189     const bool omit_stride =
   2190         std::all_of(slice_strides_.begin(), slice_strides_.end(),
   2191                     [](int64 stride) { return stride == 1; });
   2192     for (int i = 0; i < slice_starts_.size(); ++i) {
   2193       string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
   2194       bounds.push_back(StrCat("[", slice_starts_[i], ":", slice_limits_[i],
   2195                               stride_str, "]"));
   2196     }
   2197     extra.push_back(StrCat("slice={", Join(bounds, ", "), "}"));
   2198   }
   2199   if (opcode() == HloOpcode::kDynamicSlice) {
   2200     extra.push_back(
   2201         StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}"));
   2202   }
   2203   if (opcode() == HloOpcode::kBatchNormTraining ||
   2204       opcode() == HloOpcode::kBatchNormInference ||
   2205       opcode() == HloOpcode::kBatchNormGrad) {
   2206     extra.push_back(StrCat("epsilon=", epsilon()));
   2207     extra.push_back(StrCat("feature_index=", feature_index()));
   2208   }
   2209 
   2210   if (convolution_dimension_numbers_ != nullptr) {
   2211     extra.push_back(ConvolutionDimensionNumbersToString());
   2212   }
   2213   if (dot_dimension_numbers_ != nullptr) {
   2214     extra.push_back(DotDimensionNumbersToString());
   2215   }
   2216   if (gather_dimension_numbers_ != nullptr) {
   2217     extra.push_back(GatherDimensionNumbersToString());
   2218     extra.push_back(
   2219         StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}"));
   2220   }
   2221   if (opcode() == HloOpcode::kFft) {
   2222     extra.push_back(StrCat("fft_type=", FftType_Name(fft_type())));
   2223     extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}"));
   2224   }
   2225 
   2226   if (options.print_subcomputation_references()) {
   2227     if (opcode() == HloOpcode::kWhile) {
   2228       extra.push_back(
   2229           StrCat("condition=", PrintName(while_condition()->name(), options)));
   2230       extra.push_back(
   2231           StrCat("body=", PrintName(while_body()->name(), options)));
   2232     } else if (opcode() == HloOpcode::kSelectAndScatter) {
   2233       extra.push_back(StrCat("select=", PrintName(select()->name(), options)));
   2234       extra.push_back(
   2235           StrCat("scatter=", PrintName(scatter()->name(), options)));
   2236     } else if (opcode() == HloOpcode::kConditional) {
   2237       extra.push_back(StrCat("true_computation=",
   2238                              PrintName(true_computation()->name(), options)));
   2239       extra.push_back(StrCat("false_computation=",
   2240                              PrintName(false_computation()->name(), options)));
   2241     } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
   2242                opcode() == HloOpcode::kReduceWindow ||
   2243                opcode() == HloOpcode::kReduce) {
   2244       extra.push_back(
   2245           StrCat("to_apply=", PrintName(to_apply()->name(), options)));
   2246     } else if (!called_computations().empty()) {
   2247       extra.push_back(StrCat(
   2248           "calls=", Join(called_computations(), ", ",
   2249                          [&](string* out, const HloComputation* computation) {
   2250                            StrAppend(out,
   2251                                      PrintName(computation->name(), options));
   2252                          })));
   2253     }
   2254   }
   2255 
   2256   if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv ||
   2257       opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) {
   2258     extra.push_back(StrCat("channel_id=", channel_id_));
   2259   }
   2260 
   2261   if (opcode() == HloOpcode::kGetTupleElement) {
   2262     extra.push_back(StrCat("index=", tuple_index()));
   2263   }
   2264   if (has_sharding()) {
   2265     extra.push_back(StrCat("sharding=", sharding().ToString()));
   2266   }
   2267   if (!control_predecessors_.empty()) {
   2268     extra.push_back(StrCat("control-predecessors={",
   2269                            Join(control_predecessors_, ", ",
   2270                                 [&](string* out, HloInstruction* pre) {
   2271                                   StrAppend(out,
   2272                                             PrintName(pre->name(), options));
   2273                                 }),
   2274                            "}"));
   2275   }
   2276   if (opcode() == HloOpcode::kInfeed && !infeed_config_.empty()) {
   2277     extra.push_back(StrCat("infeed_config=\"", CEscape(infeed_config_), "\""));
   2278   }
   2279   if (opcode() == HloOpcode::kOutfeed && !outfeed_config_.empty()) {
   2280     extra.push_back(
   2281         StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\""));
   2282   }
   2283   if (opcode() == HloOpcode::kRng) {
   2284     extra.push_back(
   2285         StrCat("distribution=", RandomDistributionToString(distribution_)));
   2286   }
   2287   if (opcode() == HloOpcode::kReducePrecision) {
   2288     extra.push_back(StrCat("exponent_bits=", exponent_bits_));
   2289     extra.push_back(StrCat("mantissa_bits=", mantissa_bits_));
   2290   }
   2291 
   2292   // By contract, we print the custom call target even if
   2293   // !options.print_subcomputation_references(), because the call target is not
   2294   // an HloComputation.
   2295   if (opcode() == HloOpcode::kCustomCall) {
   2296     extra.push_back(
   2297         StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
   2298   }
   2299   return extra;
   2300 }
   2301 
   2302 string HloInstruction::ToShortString() const {
   2303   return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
   2304                 Join(operands_, ", ",
   2305                      [](string* out, HloInstruction* operand) {
   2306                        StrAppend(out, "%", operand->name());
   2307                      }),
   2308                 ")");
   2309 }
   2310 
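         // Serializes this instruction to an HloInstructionProto. Operands, control
         // predecessors and called computations are referenced by name; for kFusion
         // the fused computation is embedded in the proto instead.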
   2311 HloInstructionProto HloInstruction::ToProto() const {
   2312   HloInstructionProto proto;
   2313   proto.set_name(name_);
   2314   proto.set_opcode(HloOpcodeString(opcode_));
   2315   *proto.mutable_shape() = shape_;
   2316   for (const HloInstruction* operand : operands_) {
   2317     *proto.add_operand_names() = operand->name();
   2318   }
   2319   for (const HloInstruction* control : control_predecessors_) {
   2320     *proto.add_control_predecessor_names() = control->name();
   2321   }
   2322 
   2323   *proto.mutable_metadata() = metadata_;
   2324   if (literal_ != nullptr) {
   2325     *proto.mutable_literal() = literal_->ToProto();
   2326   }
   2327   proto.set_parameter_number(parameter_number_);
   2328   if (opcode() == HloOpcode::kFusion) {
   2329     proto.set_fusion_kind(xla::ToString(fusion_kind()));
   2330     *proto.mutable_fused_instructions_computation() =
   2331         fused_instructions_computation()->ToProto();
   2332   } else {
   2333     for (const HloComputation* computation : called_computations_) {
   2334       *proto.add_called_computation_names() = computation->name();
   2335     }
   2336   }
   2337 
   2338   proto.set_tuple_index(tuple_index_);
   2339   for (int64 dimension : dimensions_) {
   2340     proto.add_dimensions(dimension);
   2341   }
   2342   if (window_ != nullptr) {
   2343     *proto.mutable_window() = *window_;
   2344   }
   2345   if (convolution_dimension_numbers_ != nullptr) {
   2346     *proto.mutable_convolution_dimension_numbers() =
   2347         *convolution_dimension_numbers_;
   2348   }
   2349   if (dot_dimension_numbers_ != nullptr) {
   2350     *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_;
   2351   }
   2352   if (gather_dimension_numbers_ != nullptr) {
   2353     *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_;
   2354   }
   2355   if (opcode() == HloOpcode::kGather) {
   2356     for (int64 bound : gather_window_bounds()) {
   2357       proto.add_gather_window_bounds(bound);
   2358     }
   2359   }
   2360   for (int i = 0; i < slice_starts_.size(); ++i) {
   2361     auto* slice_dimension = proto.add_slice_dimensions();
   2362     slice_dimension->set_start(slice_starts_[i]);
   2363     slice_dimension->set_limit(slice_limits_[i]);
   2364     slice_dimension->set_stride(slice_strides_[i]);
   2365   }
   2366   proto.set_exponent_bits(exponent_bits_);
   2367   proto.set_mantissa_bits(mantissa_bits_);
   2368   for (int64 slice_size : dynamic_slice_sizes_) {
   2369     proto.add_dynamic_slice_sizes(slice_size);
   2370   }
   2371   if (padding_config_ != nullptr) {
   2372     *proto.mutable_padding_config() = *padding_config_;
   2373   }
   2374   proto.set_outfeed_config(outfeed_config_);
   2375   if (opcode() == HloOpcode::kRng) {
   2376     proto.set_distribution(distribution_);
   2377   }
   2378   proto.set_epsilon(epsilon_);
   2379   proto.set_feature_index(feature_index_);
   2380   proto.set_channel_id(channel_id_);
   2381   proto.set_infeed_config(infeed_config_);
   2382   proto.set_custom_call_target(custom_call_target_);
   2383   *proto.mutable_outfeed_shape() = outfeed_shape_;
   2384   proto.set_fft_type(fft_type_);
   2385   for (int64 fft_len : fft_length_) {
   2386     proto.add_fft_length(fft_len);
   2387   }
   2388 
   2389   return proto;
   2390 }
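        
        // Rough sketch of the round trip (an illustration, not part of the original
        // source): a parameter instruction of shape f32[8] serializes to a proto
        // carrying (at least) its name, the opcode string "parameter", the shape,
        // and parameter_number, and it can be rebuilt with CreateFromProto above,
        // provided the instruction_map and computation_map passed there resolve any
        // operand and called-computation names it refers to.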
   2391 
   2392 string HloInstruction::ToCategory() const {
   2393   if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy ||
   2394       opcode() == HloOpcode::kReshape) {
   2395     return "data formatting";
   2396   }
   2397 
   2398   if (opcode() == HloOpcode::kConvolution) {
   2399     string category = "convolution";
   2400     if (window_util::HasBaseDilation(window())) {
   2401       category += " base-dilated";
   2402     }
   2403     if (window_util::HasWindowDilation(window())) {
   2404       category += " window-dilated";
   2405     }
   2406     return category;
   2407   }
   2408 
   2409   // Give transpose-dot and backwards-conv fusions the categories "dot" and
   2410   // "convolution" so they match the categories of proper kDot and kConvolution
   2411   // ops.  These fusion categories are really just a way of expressing a
   2412   // particular kind of dot or conv, so they should have the same category as a
   2413   // vanilla dot/conv.
   2414   if (opcode() == HloOpcode::kFusion) {
   2415     switch (fusion_kind()) {
   2416       case FusionKind::kLoop:
   2417         return "loop fusion";
   2418       case FusionKind::kInput:
   2419         return "input fusion";
   2420       case FusionKind::kOutput:
   2421         return "output fusion";
   2422       case FusionKind::kTransposeDot:
   2423         return "dot";
   2424       case FusionKind::kCustom:
   2425         return "custom fusion";
   2426     }
   2427   }
   2428 
   2429   if (IsElementwise()) {
   2430     return "non-fusion elementwise";
   2431   }
   2432 
   2433   return HloOpcodeString(opcode());
   2434 }
   2435 
   2436 HloInstruction* HloInstruction::tracing() const { return trace_instruction_; }
   2437 
   2438 void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
   2439   trace_instruction_ = trace_instruction;
   2440 }
   2441 
   2442 string HloInstruction::TracingTag() const {
   2443   CHECK_EQ(HloOpcode::kTrace, opcode());
   2444   CHECK(literal_ != nullptr);
   2445   return literal_->GetR1U8AsString();
   2446 }
   2447 
   2448 bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
   2449 
   2450 bool HloInstruction::IsFusable() const {
   2451   // Instructions which are traced should not be fused.
   2452   if (tracing()) {
   2453     return false;
   2454   }
   2455   // Some kinds of instructions don't make sense to fuse.
   2456   switch (opcode_) {
   2457     case HloOpcode::kParameter:
   2458       return false;
   2459     // Side-effecting instructions cannot be fused.
   2460     default:
   2461       return !HasSideEffect();
   2462   }
   2463 }
   2464 
   2465 HloComputation* HloInstruction::fused_instructions_computation() const {
   2466   CHECK_EQ(opcode_, HloOpcode::kFusion);
   2467   CHECK(!called_computations_.empty());
   2468   auto* fused_instructions_computation = called_computations_.front();
   2469   CHECK(fused_instructions_computation->IsFusionComputation());
   2470   return fused_instructions_computation;
   2471 }
   2472 
   2473 HloInstruction* HloInstruction::fused_expression_root() const {
   2474   CHECK_EQ(opcode_, HloOpcode::kFusion);
   2475   return fused_instructions_computation()->root_instruction();
   2476 }
   2477 
   2478 HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
   2479   CHECK_EQ(opcode_, HloOpcode::kFusion);
   2480   return fused_instructions_computation()->parameter_instruction(
   2481       parameter_number);
   2482 }
   2483 
   2484 const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
   2485   CHECK_EQ(opcode_, HloOpcode::kFusion);
   2486   return fused_instructions_computation()->parameter_instructions();
   2487 }
   2488 
   2489 const tensorflow::gtl::iterator_range<UnwrappingIterator<
   2490     std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
   2491 HloInstruction::fused_instructions() const {
   2492   CHECK_EQ(opcode_, HloOpcode::kFusion);
   2493   const HloComputation* subcomp = fused_instructions_computation();
   2494   return subcomp->instructions();
   2495 }
   2496 
   2497 const tensorflow::gtl::iterator_range<
   2498     UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
   2499 HloInstruction::fused_instructions() {
   2500   CHECK_EQ(opcode_, HloOpcode::kFusion);
   2501   return fused_instructions_computation()->instructions();
   2502 }
   2503 
   2504 int64 HloInstruction::fused_instruction_count() const {
   2505   return fused_instructions_computation()->instruction_count();
   2506 }
   2507 
   2508 HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
   2509     : unique_id_(-1),
   2510       opcode_(opcode),
   2511       shape_(shape),
   2512       name_(HloOpcodeString(opcode)) {
   2513   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
   2514 }
   2515 
   2516 template <typename HloInstructionPtr>
   2517 Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
   2518   switch (opcode_) {
   2519     case HloOpcode::kAbs:
   2520       return visitor->HandleAbs(this);
   2521     case HloOpcode::kAtan2:
   2522       return visitor->HandleAtan2(this);
   2523     case HloOpcode::kRoundNearestAfz:
   2524       return visitor->HandleRound(this);
   2525     case HloOpcode::kBatchNormTraining:
   2526       return visitor->HandleBatchNormTraining(this);
   2527     case HloOpcode::kBatchNormInference:
   2528       return visitor->HandleBatchNormInference(this);
   2529     case HloOpcode::kBatchNormGrad:
   2530       return visitor->HandleBatchNormGrad(this);
   2531     case HloOpcode::kSign:
   2532       return visitor->HandleSign(this);
   2533     case HloOpcode::kConstant:
   2534       return visitor->HandleConstant(this);
   2535     case HloOpcode::kGetTupleElement:
   2536       return visitor->HandleGetTupleElement(this);
   2537     case HloOpcode::kParameter:
   2538       return visitor->HandleParameter(this);
   2539     case HloOpcode::kEq:
   2540     case HloOpcode::kGe:
   2541     case HloOpcode::kGt:
   2542     case HloOpcode::kLe:
   2543     case HloOpcode::kLt:
   2544     case HloOpcode::kNe:
   2545       return visitor->HandleCompare(this);
   2546     case HloOpcode::kComplex:
   2547       return visitor->HandleComplex(this);
   2548     case HloOpcode::kAdd:
   2549       return visitor->HandleAdd(this);
   2550     case HloOpcode::kDivide:
   2551       return visitor->HandleDivide(this);
   2552     case HloOpcode::kSubtract:
   2553       return visitor->HandleSubtract(this);
   2554     case HloOpcode::kMaximum:
   2555       return visitor->HandleMaximum(this);
   2556     case HloOpcode::kMinimum:
   2557       return visitor->HandleMinimum(this);
   2558     case HloOpcode::kAnd:
   2559       return visitor->HandleAnd(this);
   2560     case HloOpcode::kOr:
   2561       return visitor->HandleOr(this);
   2562     case HloOpcode::kShiftLeft:
   2563       return visitor->HandleShiftLeft(this);
   2564     case HloOpcode::kShiftRightArithmetic:
   2565       return visitor->HandleShiftRightArithmetic(this);
   2566     case HloOpcode::kShiftRightLogical:
   2567       return visitor->HandleShiftRightLogical(this);
   2568     case HloOpcode::kConcatenate:
   2569       return visitor->HandleConcatenate(this);
   2570     case HloOpcode::kConvert:
   2571       return visitor->HandleConvert(this);
   2572     case HloOpcode::kBitcastConvert:
   2573       return visitor->HandleBitcastConvert(this);
   2574     case HloOpcode::kCopy:
   2575       return visitor->HandleCopy(this);
   2576     case HloOpcode::kMultiply:
   2577       return visitor->HandleMultiply(this);
   2578     case HloOpcode::kDot:
   2579       return visitor->HandleDot(this);
   2580     case HloOpcode::kPower:
   2581       return visitor->HandlePower(this);
   2582     case HloOpcode::kRemainder:
   2583       return visitor->HandleRemainder(this);
   2584     case HloOpcode::kSelect:
   2585       return visitor->HandleSelect(this);
   2586     case HloOpcode::kConvolution:
   2587       return visitor->HandleConvolution(this);
   2588     case HloOpcode::kFft:
   2589       return visitor->HandleFft(this);
   2590     case HloOpcode::kCrossReplicaSum:
   2591       return visitor->HandleCrossReplicaSum(this);
   2592     case HloOpcode::kTuple:
   2593       return visitor->HandleTuple(this);
   2594     case HloOpcode::kMap:
   2595       return visitor->HandleMap(this);
   2596     case HloOpcode::kClamp:
   2597       return visitor->HandleClamp(this);
   2598     case HloOpcode::kReduce:
   2599       return visitor->HandleReduce(this);
   2600     case HloOpcode::kReduceWindow:
   2601       return visitor->HandleReduceWindow(this);
   2602     case HloOpcode::kSelectAndScatter:
   2603       return visitor->HandleSelectAndScatter(this);
   2604     case HloOpcode::kNegate:
   2605       return visitor->HandleNegate(this);
   2606     case HloOpcode::kExp:
   2607       return visitor->HandleExp(this);
   2608     case HloOpcode::kFloor:
   2609       return visitor->HandleFloor(this);
   2610     case HloOpcode::kCeil:
   2611       return visitor->HandleCeil(this);
   2612     case HloOpcode::kLog:
   2613       return visitor->HandleLog(this);
   2614     case HloOpcode::kTanh:
   2615       return visitor->HandleTanh(this);
   2616     case HloOpcode::kCos:
   2617       return visitor->HandleCos(this);
   2618     case HloOpcode::kSin:
   2619       return visitor->HandleSin(this);
   2620     case HloOpcode::kReal:
   2621       return visitor->HandleReal(this);
   2622     case HloOpcode::kImag:
   2623       return visitor->HandleImag(this);
   2624     case HloOpcode::kIsFinite:
   2625       return visitor->HandleIsFinite(this);
   2626     case HloOpcode::kNot:
   2627       return visitor->HandleNot(this);
   2628     case HloOpcode::kBitcast:
   2629       return visitor->HandleBitcast(this);
   2630     case HloOpcode::kBroadcast:
   2631       return visitor->HandleBroadcast(this);
   2632     case HloOpcode::kPad:
   2633       return visitor->HandlePad(this);
   2634     case HloOpcode::kReshape:
   2635       return visitor->HandleReshape(this);
   2636     case HloOpcode::kTranspose:
   2637       return visitor->HandleTranspose(this);
   2638     case HloOpcode::kReverse:
   2639       return visitor->HandleReverse(this);
   2640     case HloOpcode::kReducePrecision:
   2641       return visitor->HandleReducePrecision(this);
   2642     case HloOpcode::kSlice:
   2643       return visitor->HandleSlice(this);
   2644     case HloOpcode::kDynamicSlice:
   2645       return visitor->HandleDynamicSlice(this);
   2646     case HloOpcode::kDynamicUpdateSlice:
   2647       return visitor->HandleDynamicUpdateSlice(this);
   2648     case HloOpcode::kSort:
   2649       return visitor->HandleSort(this);
   2650     case HloOpcode::kInfeed:
   2651       return visitor->HandleInfeed(this);
   2652     case HloOpcode::kOutfeed:
   2653       return visitor->HandleOutfeed(this);
   2654     case HloOpcode::kHostCompute:
   2655       return visitor->HandleHostCompute(this);
   2656     case HloOpcode::kRng:
   2657       return visitor->HandleRng(this);
   2658     case HloOpcode::kWhile:
   2659       return visitor->HandleWhile(this);
   2660     case HloOpcode::kFusion:
   2661       return visitor->HandleFusion(this);
   2662     case HloOpcode::kCall:
   2663       return visitor->HandleCall(this);
   2664     case HloOpcode::kConditional:
   2665       return visitor->HandleConditional(this);
   2666     case HloOpcode::kCustomCall:
   2667       return visitor->HandleCustomCall(this);
   2668     case HloOpcode::kRecv:
   2669       return visitor->HandleRecv(this);
   2670     case HloOpcode::kRecvDone:
   2671       return visitor->HandleRecvDone(this);
   2672     case HloOpcode::kSend:
   2673       return visitor->HandleSend(this);
   2674     case HloOpcode::kSendDone:
   2675       return visitor->HandleSendDone(this);
   2676     case HloOpcode::kGather:
   2677       return visitor->HandleGather(this);
   2678 
   2679     // These opcodes are not handled here.
   2680     case HloOpcode::kTrace:
   2681       break;
   2682   }
   2683   return Unimplemented("unhandled HloOpcode for DfsHloVisitor: %s",
   2684                        HloOpcodeString(opcode_).c_str());
   2685 }
   2686 
   2687 // Explicit instantiations.
   2688 template Status HloInstruction::Visit(DfsHloVisitor* visitor);
   2689 template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
   2690 
   2691 using DFSStack =
   2692     tensorflow::gtl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
   2693 
   2694 // Push "child" onto the dfs_stack if not already visited.  Returns false if a
   2695 // cycle was detected, and true otherwise.
   2696 template <typename Visitor>
   2697 inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
   2698                          HloInstruction* child) {
   2699   CHECK(child != nullptr);
   2700   const int id = child->unique_id();
   2701   CHECK_GE(id, 0) << "instruction may not have a parent computation";
   2702   switch (visitor->GetVisitState(id)) {
   2703     case Visitor::kVisiting:
   2704       return false;
   2705 
   2706     case Visitor::kVisited:
   2707       // Nothing to do
   2708       return true;
   2709 
   2710     case Visitor::kNotVisited:
   2711       dfs_stack->push_back(std::make_pair(id, child));
   2712       return true;
   2713   }
   2714 }
   2715 
   2716 using InternalCompareFunction =
   2717     std::function<bool(std::pair<int, const HloInstruction*>,
   2718                        std::pair<int, const HloInstruction*>)>;
   2719 template <typename Visitor>
   2720 static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
   2721                            const InternalCompareFunction* operand_order,
   2722                            bool ignore_control_predecessors) {
   2723   visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds());
   2724 
   2725   // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
   2726   //
   2727   // We need to keep track of both the id and the instruction because
   2728   // instructions can get deleted while they are on the stack, so we
   2729   // can't always use the (potentially dead) instruction object to grab
   2730   // its id.
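          //
          // Each node reaches the top of the stack twice: the first time it is
          // marked kVisiting and its operands (and, unless they are ignored, its
          // control predecessors) are pushed above it; the second time, once those
          // children have been handled, it is actually visited and marked kVisited.
          // For a small example such as %add = add(%x, %y), the resulting visit
          // order is %x, %y, %add, i.e. a post-order traversal.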
   2731   DFSStack dfs_stack;
   2732   dfs_stack.emplace_back(root->unique_id(), root);
   2733 
   2734   do {
   2735     DCHECK(!dfs_stack.empty());
   2736 
   2737     int current_id = dfs_stack.back().first;
   2738     HloInstruction* current_node = dfs_stack.back().second;
   2739     CHECK_GE(current_id, 0) << current_id << ": " << current_node
   2740                             << ": instruction may not have parent computation";
   2741     typename Visitor::VisitState visit_state =
   2742         visitor->GetVisitState(current_id);
   2743     if (visit_state == Visitor::kVisited) {
   2744       dfs_stack.pop_back();
   2745       VLOG(3) << "Not visiting HLO %" << current_node->name()
   2746               << " as it was already visited.";
   2747       continue;
   2748     }
   2749 
   2750     if (visit_state == Visitor::kVisiting) {
   2751       dfs_stack.pop_back();
   2752 
   2753       TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
   2754       VLOG(2) << "Visiting HLO %" << current_node->name();
   2755       TF_RETURN_IF_ERROR(current_node->Visit(visitor));
   2756       visitor->SetVisitState(current_id, Visitor::kVisited);
   2757       TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
   2758       continue;
   2759     }
   2760 
   2761     visitor->SetVisitState(current_id, Visitor::kVisiting);
   2762 
   2763     const size_t old_dfs_stack_size = dfs_stack.size();
   2764     for (HloInstruction* child : current_node->operands()) {
   2765       if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
   2766         return FailedPrecondition(
   2767             "A cycle is detected while visiting instruction %s",
   2768             current_node->ToString().c_str());
   2769       }
   2770     }
   2771 
   2772     if (!ignore_control_predecessors) {
   2773       for (HloInstruction* child : current_node->control_predecessors()) {
   2774         if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
   2775           return FailedPrecondition(
   2776               "A cycle is detected while visiting instruction %s",
   2777               current_node->ToString().c_str());
   2778         }
   2779       }
   2780     }
   2781 
   2782     if (operand_order != nullptr) {
   2783       std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
   2784                 *operand_order);
   2785     }
   2786 
   2787     // This makes the traversal order the same as what you'd expect
   2788     // out of a recursive algorithm.
   2789     std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
   2790   } while (!dfs_stack.empty());
   2791 
   2792   return Status::OK();
   2793 }
   2794 
   2795 template <typename HloInstructionPtr>
   2796 Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
   2797                               bool call_finish_visit,
   2798                               bool ignore_control_predecessors) {
   2799   VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
   2800   TF_RETURN_IF_ERROR(
   2801       PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
   2802   if (call_finish_visit) {
   2803     TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
   2804   }
   2805   return Status::OK();
   2806 }
   2807 
   2808 // Explicit instantiations.
   2809 template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
   2810 template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
   2811 
   2812 Status HloInstruction::AcceptWithOperandOrder(
   2813     DfsHloVisitor* visitor, const CompareFunction& operand_order,
   2814     bool call_finish_visit) {
   2815   VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";
   2816   InternalCompareFunction func = [&operand_order](
   2817                                      std::pair<int, const HloInstruction*> a,
   2818                                      std::pair<int, const HloInstruction*> b) {
   2819     // Call the client's comparison function on the actual HloInstruction*
   2820     // objects (ignoring the internal ids we also have in our stack entries)
   2821     return operand_order(a.second, b.second);
   2822   };
   2823   TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func,
   2824                                   /*ignore_control_predecessors=*/false));
   2825   if (call_finish_visit) {
   2826     VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
   2827     TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
   2828     VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
   2829   }
   2830   VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
   2831   return Status::OK();
   2832 }
   2833 
   2834 namespace {
   2835 
   2836 // Returns true if the given order is a topological sort of the instructions
   2837 // it contains.
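        // For example, the order {%x, %y, %add = add(%x, %y)} is a valid
        // topological sort, while {%add, %x, %y} is not, because %add appears
        // before its operands.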
   2838 bool OrderIsTopologicalSort(const std::vector<const HloInstruction*>& order) {
   2839   // Create a map from instruction to its position in 'order'.
   2840   std::unordered_map<const HloInstruction*, int> order_position;
   2841   for (int i = 0; i < order.size(); i++) {
   2842     if (!order_position.insert({order[i], i}).second) {
   2843       // Instruction order[i] is duplicated in the order.
   2844       return false;
   2845     }
   2846   }
   2847   // Verify that every operand of each instruction in the order is also in the
   2848   // order *and* appears at an earlier position (defs come before uses for all
   2849   // ops).
   2850   for (auto* instruction : order) {
   2851     for (auto* operand : instruction->operands()) {
   2852       if (!ContainsKey(order_position, operand) ||
   2853           order_position.at(operand) >= order_position.at(instruction)) {
   2854         return false;
   2855       }
   2856     }
   2857   }
   2858 
   2859   return true;
   2860 }
   2861 
   2862 }  // namespace
   2863 
   2864 Status HloInstruction::Accept(
   2865     const std::function<Status(HloInstruction*)>& visitor_func) {
   2866   FunctionVisitor visitor(visitor_func);
   2867   return this->Accept(&visitor);
   2868 }
   2869 
   2870 Status HloInstruction::Accept(
   2871     const std::function<Status(const HloInstruction*)>& visitor_func) const {
   2872   ConstFunctionVisitor visitor(visitor_func);
   2873   return this->Accept(&visitor);
   2874 }
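        
        // Example usage of the lambda-based overloads above (a sketch; "root" here
        // stands for any HloInstruction* in a Status-returning context): counting
        // the reachable instructions.
        //
        //   int64 count = 0;
        //   TF_RETURN_IF_ERROR(root->Accept([&count](HloInstruction*) {
        //     ++count;
        //     return Status::OK();
        //   }));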
   2875 
   2876 Status HloInstruction::AcceptOrdered(
   2877     DfsHloVisitor* visitor, const std::vector<const HloInstruction*>& order) {
   2878   VLOG(2) << "HloInstruction::AcceptOrdered(%" << name() << ")";
   2879   TF_RET_CHECK(OrderIsTopologicalSort(order));
   2880 
   2881   // Compute the predecessors of this instruction.
   2882   std::unordered_set<const HloInstruction*> predecessors;
   2883   TF_RETURN_IF_ERROR(this->Accept([&predecessors](HloInstruction* instruction) {
   2884     predecessors.insert(instruction);
   2885     return Status::OK();
   2886   }));
   2887 
   2888   for (auto* const_instruction : order) {
   2889     if (!ContainsKey(predecessors, const_instruction)) {
   2890       // Instruction is not a predecessor of 'this'.
   2891       continue;
   2892     }
   2893 
   2894     // The visitor can mark individual instructions as visited in order to
   2895     // skip them.
   2896     if (visitor->DidVisit(*const_instruction)) {
   2897       VLOG(3) << "Not visiting HLO %" << const_instruction->name()
   2898               << " as it was already visited.";
   2899       continue;
   2900     }
   2901 
   2902     HloInstruction* instruction =
   2903         const_cast<HloInstruction*>(const_instruction);
   2904 
   2905     TF_RETURN_IF_ERROR(visitor->Preprocess(instruction));
   2906     VLOG(2) << "Visiting HLO %" << instruction->name();
   2907     TF_RETURN_IF_ERROR(instruction->Visit(visitor));
   2908     visitor->SetVisited(*instruction);
   2909     TF_RETURN_IF_ERROR(visitor->Postprocess(instruction));
   2910   }
   2911 
   2912   return visitor->FinishVisit(this);
   2913 }
   2914 
   2915 const Shape& HloInstruction::outfeed_shape() const {
   2916   DCHECK_EQ(opcode_, HloOpcode::kOutfeed);
   2917   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
   2918   return outfeed_shape_;
   2919 }
   2920 
   2921 const Shape& HloInstruction::shape() const {
   2922   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
   2923   return shape_;
   2924 }
   2925 
   2926 std::vector<int64> HloInstruction::OperandIndices(
   2927     const HloInstruction* operand) const {
   2928   std::vector<int64> result;
   2929   for (int64 i = 0; i < operand_count(); ++i) {
   2930     if (this->operand(i) == operand) {
   2931       result.push_back(i);
   2932     }
   2933   }
   2934   return result;
   2935 }
   2936 
   2937 bool HloInstruction::IsElementwiseBinary() const {
   2938   return IsElementwise() && operand_count() == 2;
   2939 }
   2940 
   2941 bool HloInstruction::IsElementwise() const {
   2942   switch (opcode_) {
   2943     // Nullary elementwise operations.
   2944     case HloOpcode::kConstant:
   2945       return true;
   2946 
   2947     // Unary elementwise operations.
   2948     case HloOpcode::kAbs:
   2949     case HloOpcode::kRoundNearestAfz:
   2950     case HloOpcode::kCeil:
   2951     case HloOpcode::kConvert:
   2952     case HloOpcode::kBitcastConvert:
   2953     case HloOpcode::kCopy:
   2954     case HloOpcode::kCos:
   2955     case HloOpcode::kExp:
   2956     case HloOpcode::kFloor:
   2957     case HloOpcode::kImag:
   2958     case HloOpcode::kIsFinite:
   2959     case HloOpcode::kLog:
   2960     case HloOpcode::kNot:
   2961     case HloOpcode::kNegate:
   2962     case HloOpcode::kReal:
   2963     case HloOpcode::kReducePrecision:
   2964     case HloOpcode::kSign:
   2965     case HloOpcode::kSin:
   2966     case HloOpcode::kTanh:
   2967       CHECK_EQ(1, operand_count());
   2968       return true;
   2969 
   2970     // Binary elementwise operations, the same as in IsElementwiseBinary().
   2971     case HloOpcode::kAdd:
   2972     case HloOpcode::kAtan2:
   2973     case HloOpcode::kComplex:
   2974     case HloOpcode::kDivide:
   2975     case HloOpcode::kEq:
   2976     case HloOpcode::kGe:
   2977     case HloOpcode::kGt:
   2978     case HloOpcode::kLe:
   2979     case HloOpcode::kLt:
   2980     case HloOpcode::kMaximum:
   2981     case HloOpcode::kMinimum:
   2982     case HloOpcode::kMultiply:
   2983     case HloOpcode::kNe:
   2984     case HloOpcode::kPower:
   2985     case HloOpcode::kRemainder:
   2986     case HloOpcode::kSubtract:
   2987     case HloOpcode::kAnd:
   2988     case HloOpcode::kOr:
   2989     case HloOpcode::kShiftLeft:
   2990     case HloOpcode::kShiftRightArithmetic:
   2991     case HloOpcode::kShiftRightLogical:
   2992       CHECK_EQ(2, operand_count());
   2993       return true;
   2994 
   2995     // Ternary elementwise operations.
   2996     case HloOpcode::kSelect:
   2997       return !ShapeUtil::IsTuple(shape_);
   2998     case HloOpcode::kClamp:
   2999       return true;
   3000 
   3001     // Other operations.
   3002     case HloOpcode::kRng:
   3003     case HloOpcode::kMap:
   3004       return true;
   3005     case HloOpcode::kFusion:
   3006       if (fusion_kind() != FusionKind::kLoop) {
   3007         return false;
   3008       }
   3009       for (auto* fused : fused_instructions()) {
   3010         if (fused->opcode() != HloOpcode::kParameter &&
   3011             !fused->IsElementwise()) {
   3012           return false;
   3013         }
   3014       }
   3015       return true;
   3016 
   3017     default:
   3018       return false;
   3019   }
   3020 }
   3021 
   3022 bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const {
   3023   CHECK(IsElementwise());
   3024   return !ShapeUtil::Equal(shape(), operand(operand_idx)->shape());
   3025 }
   3026 
   3027 namespace {
   3028 bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
   3029                                        const HloInstruction* operand) {
   3030   std::vector<int64> operand_indices = instruction->OperandIndices(operand);
   3031   return std::all_of(
   3032       operand_indices.begin(), operand_indices.end(),
   3033       [instruction](int64 operand_index) {
   3034         return instruction->IsElementwiseOnOperand(operand_index);
   3035       });
   3036 }
   3037 }  // namespace
   3038 
   3039 bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
   3040   // For all instructions other than kFusion, being elementwise on one of the
   3041   // operands is equivalent to being elementwise on all the operands.
   3042   if (opcode() != HloOpcode::kFusion) {
   3043     return IsElementwise();
   3044   }
   3045 
   3046   CHECK_EQ(HloOpcode::kFusion, opcode());
   3047   if (fusion_kind() != FusionKind::kLoop) {
   3048     return false;
   3049   }
   3050 
   3051   // A loop-fusion is elementwise on an operand if all operations (computed
   3052   // using BFS) between the operand and the fused root are elementwise.
   3053   std::deque<HloInstruction*> worklist;
   3054   std::unordered_set<const HloInstruction*> visited;
   3055   worklist.push_back(fused_parameter(operand_idx));
   3056   visited.insert(fused_parameter(operand_idx));
   3057   while (!worklist.empty()) {
   3058     HloInstruction* operand = worklist.front();
   3059     worklist.pop_front();
   3060     for (HloInstruction* user : operand->users()) {
   3061       CHECK_GE(user->unique_id(), 0);
   3062       if (ContainsKey(visited, user)) {
   3063         continue;
   3064       }
   3065       if (user->IsElementwise() ||
   3066           IsInstructionElementwiseOnOperand(user, operand)) {
   3067         worklist.push_back(user);
   3068         visited.insert(user);
   3069       } else {
   3070         return false;
   3071       }
   3072     }
   3073   }
   3074   return true;
   3075 }
   3076 
   3077 // A helper class for the memoized, recursive computation of the
   3078 // HloOpcode::kFusion case in HloInstruction::OperandElementUse below.
   3079 class HloInstruction::FusionReusesParamElements {
   3080  public:
   3081   using UseKind = HloInstruction::UseKind;
   3082 
   3083   // We could instead iterate backwards through fused_instructions_ here (it is
   3084   // in reverse postorder) and compute whether each fused instruction reuses the
   3085   // value of this parameter. That would save stack space, but it would not let
   3086   // us finish early once a reuse is found.
   3087   static UseKind Compute(int64 i, const HloInstruction& hlo) {
   3088     tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> memoization_cache;
   3089     return ComputeInternal(i, hlo, &memoization_cache);
   3090   }
   3091 
   3092  private:
   3093   static UseKind ComputeInternal(
   3094       int64 i, const HloInstruction& hlo,
   3095       tensorflow::gtl::FlatMap<const HloInstruction*, UseKind>* cache) {
   3096     if (hlo.opcode_ == HloOpcode::kParameter && hlo.parameter_number_ == i) {
   3097       return UseKind::kUse;
   3098     }
   3099 
   3100     auto p = cache->emplace(&hlo, UseKind{});
   3101     auto value_it = p.first;
   3102     const bool key_is_new = p.second;
   3103 
   3104     if (key_is_new) {
   3105       for (int64 j = 0; j < hlo.operands_.size(); ++j) {
   3106         UseKind old_val = value_it->second;
   3107 
   3108         // The next operation invalidates iterators.
   3109         UseKind new_val =
   3110             Plus(old_val, std::min(hlo.OperandElementUse(j),
   3111                                    ComputeInternal(i, *hlo.operand(j), cache)));
   3112 
   3113         // Re-acquire the iterator. We could work harder to do this only if
   3114         // absolutely necessary, but this code is not hot enough to warrant
   3115         // that.
   3116         value_it = cache->find(&hlo);
   3117         value_it->second = new_val;
   3118       }
   3119     }
   3120     return value_it->second;
   3121   }
   3122 
   3123   // Fold operation for UseKinds.
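          // Combining two non-kNoUse uses yields kReuse unless both are plain kUse:
          // e.g. Plus(kUse, kUsePermutingElements) folds to kReuse, while
          // Plus(kNoUse, x) is simply x.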
   3124   static UseKind Plus(UseKind a, UseKind b) {
   3125     if (a == UseKind::kNoUse) {
   3126       return b;
   3127     } else if (b == UseKind::kNoUse) {
   3128       return a;
   3129     } else if (a == UseKind::kReuse || b == UseKind::kReuse) {
   3130       return UseKind::kReuse;
   3131     } else if (a == UseKind::kUsePermutingElements ||
   3132                b == UseKind::kUsePermutingElements) {
   3133       return UseKind::kReuse;
   3134     } else {
   3135       CHECK(a == UseKind::kUse && b == UseKind::kUse);
   3136       return UseKind::kUse;
   3137     }
   3138   }
   3139 };
   3140 
   3141 HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
   3142   switch (opcode_) {
   3143     case HloOpcode::kBitcast:
   3144     case HloOpcode::kConcatenate:
   3145     case HloOpcode::kReshape:
   3146     case HloOpcode::kReverse:
   3147     case HloOpcode::kSlice:
   3148     case HloOpcode::kTranspose:
   3149       return UseKind::kUsePermutingElements;
   3150     case HloOpcode::kPad:
   3151     case HloOpcode::kReduce:
   3152       // Pad reuses the padding value but not the padded array elements.
   3153       // Reduce reuses the init value but not the operand array elements.
   3154       return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
   3155     case HloOpcode::kFusion:
   3156       // Uses the memoizing, recursive computation defined above.
   3157       return FusionReusesParamElements::Compute(i, *fused_expression_root());
   3158     case HloOpcode::kDot:
   3159       // Dot operations with inputs [A,B] * [B,1] do not re-use
   3160       // elements on their left operand.
   3161       // Dot operations with inputs [1,A] * [A,B] do not re-use
   3162       // elements on their right operand.
   3163       if (shape().dimensions_size() == 2) {
   3164         if ((i == 0 && shape().dimensions(1) == 1) ||
   3165             (i == 1 && shape().dimensions(0) == 1)) {
   3166           return UseKind::kUse;
   3167         }
   3168       }
   3169       return UseKind::kReuse;
   3170     case HloOpcode::kDynamicUpdateSlice:
   3171       // Dynamic-update-slice reuses only operand 2 (start_indices).
   3172       if (i == 0 || i == 1) {
   3173         return UseKind::kUse;
   3174       }
   3175       return UseKind::kReuse;
   3176     default:
   3177       return IsElementwise() && !ImplicitlyBroadcastsOperand(i)
   3178                  ? UseKind::kUse
   3179                  : UseKind::kReuse;
   3180   }
   3181 }
   3182 
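        // For example (a sketch of the expected behaviour), a reshape from
        // f32[2,1,3] to f32[2,3] only deletes a size-1 dimension, so the first
        // element of the returned tuple would be true, with the affected dimension
        // index reported in one of the returned vectors; a reshape that does more
        // than insert or delete size-1 dimensions yields false.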
   3183 std::tuple<bool, std::vector<int64>, std::vector<int64>>
   3184 HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
   3185   if (HloOpcode::kReshape != opcode_) {
   3186     return std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
   3187   }
   3188   return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
   3189                                                       shape_);
   3190 }
   3191 
   3192 string ToString(HloInstruction::FusionKind kind) {
   3193   switch (kind) {
   3194     case HloInstruction::FusionKind::kLoop:
   3195       return "kLoop";
   3196     case HloInstruction::FusionKind::kInput:
   3197       return "kInput";
   3198     case HloInstruction::FusionKind::kOutput:
   3199       return "kOutput";
   3200     case HloInstruction::FusionKind::kTransposeDot:
   3201       return "kTransposeDot";
   3202     case HloInstruction::FusionKind::kCustom:
   3203       return "kCustom";
   3204   }
   3205 }
   3206 
   3207 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
   3208     const string& kind_name) {
   3209   if (kind_name == "kLoop") {
   3210     return HloInstruction::FusionKind::kLoop;
   3211   }
   3212   if (kind_name == "kInput") {
   3213     return HloInstruction::FusionKind::kInput;
   3214   }
   3215   if (kind_name == "kOutput") {
   3216     return HloInstruction::FusionKind::kOutput;
   3217   }
   3218   if (kind_name == "kTransposeDot") {
   3219     return HloInstruction::FusionKind::kTransposeDot;
   3220   }
   3221   if (kind_name == "kCustom") {
   3222     return HloInstruction::FusionKind::kCustom;
   3223   }
   3224   return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str());
   3225 }
   3226 
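        // For example, a two-dimensional config with dimensions
        // {low=1, high=2, interior=0} and {low=0, high=0, interior=3} renders as
        // "1_2_0x0_0_3"; the trailing interior component is omitted entirely when
        // no dimension has interior padding.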
   3227 string PaddingConfigToString(const PaddingConfig& padding) {
   3228   bool has_interior_padding =
   3229       std::any_of(padding.dimensions().begin(), padding.dimensions().end(),
   3230                   [](const PaddingConfig::PaddingConfigDimension& dim) {
   3231                     return dim.interior_padding() != 0;
   3232                   });
   3233   return Join(
   3234       padding.dimensions(), "x",
   3235       [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
   3236         StrAppend(
   3237             out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
   3238             has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
   3239       });
   3240 }
   3241 
   3242 string OpMetadataToString(const OpMetadata& metadata) {
   3243   std::vector<string> result;
   3244   if (!metadata.op_type().empty()) {
   3245     result.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\""));
   3246   }
   3247   if (!metadata.op_name().empty()) {
   3248     result.push_back(StrCat("op_name=\"", CEscape(metadata.op_name()), "\""));
   3249   }
   3250   if (!metadata.source_file().empty()) {
   3251     result.push_back(
   3252         StrCat("source_file=\"", CEscape(metadata.source_file()), "\""));
   3253   }
   3254   if (metadata.source_line() != 0) {
   3255     result.push_back(StrCat("source_line=", metadata.source_line()));
   3256   }
   3257   return Join(result, " ");
   3258 }
   3259 
   3260 string RandomDistributionToString(const RandomDistribution& distribution) {
   3261   return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution));
   3262 }
   3263 
   3264 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
   3265   static std::unordered_map<string, RandomDistribution>* map = [] {
   3266     static auto* map = new std::unordered_map<string, RandomDistribution>;
   3267     for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
   3268       if (RandomDistribution_IsValid(i)) {
   3269         auto value = static_cast<RandomDistribution>(i);
   3270         (*map)[RandomDistributionToString(value)] = value;
   3271       }
   3272     }
   3273     return map;
   3274   }();
   3275   auto found = map->find(tensorflow::str_util::Lowercase(name));
   3276   if (found == map->end()) {
   3277     return InvalidArgument("Unknown distribution");
   3278   }
   3279   return found->second;
   3280 }
   3281 
   3282 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
   3283   return os << ToString(kind);
   3284 }
   3285 
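        // For example (sketch), with input dimension numbers {batch=0, feature=3,
        // spatial={1,2}}, kernel dimension numbers {input_feature=2,
        // output_feature=3, spatial={0,1}} and output dimension numbers matching
        // the input, this produces "dim_labels=b01f_01io->b01f".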
   3286 string HloInstruction::ConvolutionDimensionNumbersToString() const {
   3287   string result;
   3288   if (convolution_dimension_numbers_ == nullptr) {
   3289     return result;
   3290   }
   3291   const ConvolutionDimensionNumbers& dnums = *convolution_dimension_numbers_;
   3292   // Show the given dimension labels in order of major to minor based on the
   3293   // shape's layout.
   3294   const auto append_dims = [&](const std::vector<string>& dims,
   3295                                const Shape& shape) {
   3296     CHECK_EQ(dims.size(), ShapeUtil::Rank(shape));
   3297     StrAppend(&result, Join(dims, ""));
   3298   };
   3299 
   3300   // lhs_dims[i] is the symbol of the logical dimension i for the lhs
   3301   // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
   3302   std::vector<string> lhs_dims(2 + dnums.input_spatial_dimensions().size());
   3303   lhs_dims[dnums.input_batch_dimension()] = 'b';
   3304   lhs_dims[dnums.input_feature_dimension()] = 'f';
   3305   for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
   3306     lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i);
   3307   }
   3308 
   3309   std::vector<string> rhs_dims(2 + dnums.kernel_spatial_dimensions().size());
   3310   rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
   3311   rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
   3312   for (int64 i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
   3313     rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
   3314   }
   3315 
   3316   std::vector<string> output_dims(2 + dnums.output_spatial_dimensions().size());
   3317   output_dims[dnums.output_batch_dimension()] = 'b';
   3318   output_dims[dnums.output_feature_dimension()] = 'f';
   3319   for (int64 i = 0; i < dnums.output_spatial_dimensions().size(); ++i) {
   3320     output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
   3321   }
   3322 
   3323   result += "dim_labels=";
   3324   append_dims(lhs_dims, operand(0)->shape());
   3325   result += "_";
   3326   append_dims(rhs_dims, operand(1)->shape());
   3327   result += "->";
   3328 
   3329   // A convolution can be represented as a kConvolution HLO or as a CustomCall
   3330   // that returns a tuple, the first element of which is the result of the
   3331   // convolution.
   3332   Shape this_shape =
   3333       ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape();
   3334   append_dims(output_dims, this_shape);
   3335   return result;
   3336 }
   3337 
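        // For example, a plain matrix multiply that contracts the second dimension
        // of the lhs with the first dimension of the rhs (and has no batch
        // dimensions) renders as "lhs_contracting_dims={1}, rhs_contracting_dims={0}".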
   3338 string HloInstruction::DotDimensionNumbersToString() const {
   3339   std::vector<string> result;
   3340   if (dot_dimension_numbers_ == nullptr) {
   3341     return "";
   3342   }
   3343   const DotDimensionNumbers& dnums = *dot_dimension_numbers_;
   3344   if (!dnums.lhs_batch_dimensions().empty()) {
   3345     result.push_back(StrCat("lhs_batch_dims={",
   3346                             Join(dnums.lhs_batch_dimensions(), ","), "}"));
   3347   }
   3348   result.push_back(StrCat("lhs_contracting_dims={",
   3349                           Join(dnums.lhs_contracting_dimensions(), ","), "}"));
   3350 
   3351   if (!dnums.rhs_batch_dimensions().empty()) {
   3352     result.push_back(StrCat("rhs_batch_dims={",
   3353                             Join(dnums.rhs_batch_dimensions(), ","), "}"));
   3354   }
   3355   result.push_back(StrCat("rhs_contracting_dims={",
   3356                           Join(dnums.rhs_contracting_dimensions(), ","), "}"));
   3357 
   3358   return Join(result, ", ");
   3359 }
   3360 
   3361 string HloInstruction::GatherDimensionNumbersToString() const {
   3362   CHECK_NE(gather_dimension_numbers_.get(), nullptr);
   3363   string output_window_dims =
   3364       StrCat("output_window_dims={",
   3365              Join(gather_dimension_numbers_->output_window_dims(), ","), "}");
   3366   string elided_window_dims =
   3367       StrCat("elided_window_dims={",
   3368              Join(gather_dimension_numbers_->elided_window_dims(), ","), "}");
   3369   string gather_dims_to_operand_dims = StrCat(
   3370       "gather_dims_to_operand_dims={",
   3371       Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}");
   3372 
   3373   return Join<std::initializer_list<string>>(
   3374       {output_window_dims, elided_window_dims, gather_dims_to_operand_dims},
   3375       ", ");
   3376 }
   3377 
   3378 bool HloInstruction::CouldBeBitcast() const {
   3379   switch (opcode_) {
   3380     case HloOpcode::kTranspose:
   3381       return true;
   3382     case HloOpcode::kReshape:
   3383       return std::get<0>(ReshapeMerelyInsertsOrDeletes1SizedDimensions());
   3384     default:
   3385       return false;
   3386   }
   3387 }
   3388 
   3389 HloModule* HloInstruction::GetModule() const {
   3390   if (parent_) {
   3391     return parent_->parent();
   3392   }
   3393   return nullptr;
   3394 }
   3395 
   3396 void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
   3397   string parent_str = parent() == nullptr ? "noparent" : parent()->name();
   3398   name_ = name_uniquer->GetUniqueName(name_);
   3399 }
   3400 
   3401 void HloInstruction::set_outer_dimension_partitions(
   3402     const std::vector<int64>& outer_dimension_partitions) {
   3403   outer_dimension_partitions_ = outer_dimension_partitions;
   3404 }
   3405 
   3406 void HloInstruction::RelayoutConstant(const Layout& new_layout,
   3407                                       const ShapeIndex& shape_index) {
   3408   CHECK_EQ(opcode(), HloOpcode::kConstant);
   3409   Shape* mutable_array_subshape =
   3410       ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
   3411   CHECK(ShapeUtil::IsArray(*mutable_array_subshape));
   3412 
   3413   // Normally array_subshape will always have a layout, but this invariant is
   3414   // temporarily broken in LayoutAssignment::AssignLayouts.
   3415 
   3416   if (!mutable_array_subshape->has_layout() ||
   3417       !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
   3418     literal_ = literal_->Relayout(new_layout, shape_index);
   3419     *mutable_array_subshape->mutable_layout() = new_layout;
   3420   }
   3421 }
   3422 
   3423 }  // namespace xla
   3424