/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implementation note:
//
// The general idea behind this pass is that we're converting from this:
//   %param.A = OldShape
//   %param.B = OldShape
//   %reshape.A = NewShape reshape(%param.A)
//   %reshape.B = NewShape reshape(%param.B)
//   %instruction = NewShape instruction(%reshape.A, %reshape.B)
// to this:
//   %param.A = OldShape
//   %param.B = OldShape
//   %instruction = OldShape instruction(%param.A, %param.B)
//   %reshape = NewShape reshape(%instruction)
//
// where the instruction must be elementwise; both reshapes and transposes are
// moved this way.
//
// Most elementwise instructions support implicit broadcast of scalar operands,
// but select is a special case. Its signature is Select(Pred, A, B), and the
// only implicit scalar broadcast is on Pred, not on A or B. Since reshapes or
// transposes to a scalar should be cheap, we simply never move them.
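//
// For illustration, with hypothetical shapes (not from the original note):
//   %sel = f32[2,3] select(pred[] %p, f32[2,3] %reshape.A, f32[2,3] %reshape.B)
// The scalar %p is implicitly broadcast and is always left as-is; only
// %reshape.A and %reshape.B are candidates for sinking below the select.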

#include "tensorflow/compiler/xla/service/reshape_mover.h"

#include <algorithm>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

namespace {

bool IsReshapeOrTranspose(const HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kReshape ||
         instruction->opcode() == HloOpcode::kTranspose;
}
// Returns true iff `instruction` can change its shape simply by adjusting
// metadata.
bool CanTriviallyChangeShape(const HloInstruction* instruction) {
  // NOTE: Technically a sequence of reshape(reshape(constant)) is also
  // trivially reshapable, so we might be tempted to simply recurse if
  // IsReshapeOrTranspose(instruction)==true.
  //
  // But it's not that simple. E.g. reshape(reshape(rng)) is only trivially
  // reshapable if *all* instructions in the chain have user_count == 1. And
  // reshape(scalar) isn't trivial at all if the reshape itself isn't scalar; we
  // rely on implicit scalar broadcast for scalars to be trivial. In addition,
  // these cases make it harder to maintain correctness of the UpdateOperand
  // logic below.
  //
  // So don't handle these chains, unless you update the tests and code to deal
  // with them properly. One idea is to add a pass immediately beforehand that
  // collapses trivial runs of reshapes / transposes.

  // Scalars can operate with any shape.
  if (ShapeUtil::IsScalar(instruction->shape())) {
    return true;
  }

  // A constant can trivially reshape the literal it holds.
  if (instruction->opcode() == HloOpcode::kConstant) {
    return true;
  }

  // An Rng instruction can be reshaped to any shape as long as it has a single
  // user: two clones of the same Rng with different shapes could produce their
  // random numbers in a different order, so the original and the reshaped copy
  // might diverge.
  if (instruction->opcode() == HloOpcode::kRng &&
      instruction->user_count() == 1) {
    return true;
  }
  return false;
}
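
// To illustrate the single-user restriction above (hypothetical HLO): if
//   %rng = f32[2,3] rng(...)
// had two users and we cloned it as f32[6] for just one of them, the clone
// could emit the same random stream in a different order, so the two
// "copies" would no longer be guaranteed to compute the same values.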

// Finds the first non-scalar operand of an instruction that is a non-trivial
// reshape or transpose. Returns the operand if one is found, or nullptr
// otherwise.
HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand(
    const HloInstruction* hlo) {
  for (HloInstruction* operand : hlo->operands()) {
    if (!ShapeUtil::IsScalar(operand->shape()) &&
        IsReshapeOrTranspose(operand) &&
        !CanTriviallyChangeShape(operand->operand(0))) {
      VLOG(5) << "Found first non-scalar and non-trivial reshape operand of "
              << hlo->ToString(HloPrintOptions().set_print_metadata(false))
              << ":\n\t"
              << operand->ToString(HloPrintOptions().set_print_metadata(false));
      return operand;
    }
  }
  return nullptr;
}

// Returns whether `a` and `b` are equivalent for the purposes of this pass.
bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) {
  if (a->opcode() != b->opcode() ||
      !ShapeUtil::SameDimensions(a->shape(), b->shape())) {
    return false;
  }
  switch (a->opcode()) {
    case HloOpcode::kTranspose:
      return a->dimensions() == b->dimensions();
    case HloOpcode::kReshape:
      return ShapeUtil::SameDimensions(a->operand(0)->shape(),
                                       b->operand(0)->shape());
    default:
      return false;
  }
}
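
// For example (hypothetical HLO): two f32[2,3]->f32[3,2] transposes are
// equivalent iff they use the same permutation, e.g. both dimensions={1,0};
// two reshapes are equivalent iff they map between the same pair of shapes,
// e.g. both f32[2,6]->f32[3,4].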

// Returns true if all operands of `instruction` can easily change shape.
// Operands can easily change shape if they are all reshapes/transposes to and
// from the same shape. Additionally, operands like constants, rngs, and
// scalars can change shape with only an adjustment of metadata.
bool AllOperandsHaveEasyShapeChanges(
    const HloInstruction* instruction,
    const HloInstruction* first_reshape_operand) {
  auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
  VLOG(3) << "** Checking whether all operands have easy shape changes: "
          << instruction->ToString(print_no_metadata);
  // Check whether all operands:
  //    0. Have the same dimensions as the output -- if not, they may be
  //       implicitly broadcast, which can confound the movement's
  //       correctness.
  //
  // And one of the following:
  //    1. Are reshapes or transposes that have the same input and
  //       output shapes as all other reshaped or transposed operands.
  //     or
  //    2. Are a kConstant, a kRng, or a scalar, all of which can change
  //       shape trivially.
  for (const HloInstruction* operand : instruction->operands()) {
    if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) {
      VLOG(5) << "Operand shape differs from output shape; may be "
                 "implicitly broadcast, so preventing "
                 "movement\n\toperand: "
              << operand->ToString(print_no_metadata) << "\n\tinstruction: "
              << instruction->ToString(print_no_metadata);
      return false;
    }

    if (AreEquivalentReshapes(first_reshape_operand, operand)) {
      VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: "
              << first_reshape_operand->ToString(print_no_metadata)
              << "\n\toperand: " << operand->ToString(print_no_metadata);
      continue;
    }

    if (CanTriviallyChangeShape(operand)) {
      VLOG(5) << "Operand can trivially change shape: "
              << operand->ToString(print_no_metadata);
      continue;
    }

    // TODO(someone): Look into supporting general ops for the operands as
    // well.
    VLOG(5) << "Operand is neither equivalent to the first reshape operand "
               "nor able to change shape trivially: "
            << operand->ToString(print_no_metadata);
    return false;
  }

  VLOG(3) << "All operands have easy shape changes: "
          << instruction->ToString(print_no_metadata);
  return true;
}
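
// To make the contract above concrete (hypothetical HLO): with
//   %reshape.A = f32[3,4] reshape(f32[2,6] %param.A)
// as first_reshape_operand,
//   add(%reshape.A, %reshape.B)  passes if %reshape.B is also f32[2,6]->f32[3,4];
//   add(%reshape.A, %const)      passes, since a constant reshapes trivially;
//   add(%reshape.A, %other)      fails for a general f32[3,4] %other.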

// This function is called once we've decided to sink reshape/transpose
// operands across an instruction. It returns an updated `operand` with a shape
// compatible with `new_operand_shape`: either it has the same dimensions (with
// the correct element type), or it is a scalar that may be implicitly
// broadcast.
HloInstruction* UpdateOperand(HloComputation* computation,
                              const HloInstruction* first_reshape_operand,
                              const Shape& new_operand_shape,
                              HloInstruction* operand) {
  const PrimitiveType element_type = operand->shape().element_type();
  const Shape new_shape =
      ShapeUtil::ChangeElementType(new_operand_shape, element_type);

  switch (operand->opcode()) {
    case HloOpcode::kConstant: {
      if (first_reshape_operand->opcode() == HloOpcode::kReshape) {
        VLOG(5) << "Adding reshape to kConstant operand";
        return computation->AddInstruction(
            HloInstruction::CreateReshape(new_shape, operand));
      } else {
        CHECK(first_reshape_operand->opcode() == HloOpcode::kTranspose);
        VLOG(5) << "Adding transpose to kConstant operand";
        // The sunk transpose maps new_operand_shape to the instruction's
        // dimensions, so the constant (which already has the instruction's
        // dimensions) is mapped back with the inverse permutation.
        std::vector<int64> inverse_permutation =
            InversePermutation(first_reshape_operand->dimensions());
        return computation->AddInstruction(HloInstruction::CreateTranspose(
            new_shape, operand, inverse_permutation));
      }
    }
    case HloOpcode::kRng: {
      CHECK_EQ(operand->user_count(), 1);
      VLOG(5) << "Cloning kRng operand with new shape";
      return computation->AddInstruction(
          operand->CloneWithNewOperands(new_shape, operand->operands()));
    }
    case HloOpcode::kReshape:
    case HloOpcode::kTranspose: {
      VLOG(5) << "Using existing operand of kReshape or kTranspose";
      return operand->mutable_operand(0);
    }
    default:
      LOG(FATAL) << "Unexpected operand opcode during update: "
                 << operand->ToString();
  }
}
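
// Worked example of the kConstant + kTranspose case (hypothetical HLO): if
// the sunk transpose is
//   %t = f32[4,2,3] transpose(f32[2,3,4] %x), dimensions={2,0,1}
// then a f32[4,2,3] constant operand is rewritten as
//   transpose(constant), dimensions={1,2,0}
// (the inverse of {2,0,1}), which yields the pre-transpose f32[2,3,4] shape.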

// Tries to sink any reshape or transpose operands of `instruction` across it.
// We do so if `instruction` is elementwise and all operands are either
// equivalent reshapes/transposes or can trivially change shape.
StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
                                         HloInstruction* instruction) {
  // Only perform sinks for live elementwise instructions with operands.
  const bool is_dead = instruction->user_count() == 0 &&
                       instruction != computation->root_instruction();
  if (!instruction->IsElementwise() || instruction->operands().empty() ||
      is_dead) {
    return false;
  }

  // Only perform sinks if there are any nontrivial reshape/transpose operands.
  const HloInstruction* first_reshape_operand =
      FirstNonScalarAndNonTrivialReshapeOperand(instruction);
  if (!first_reshape_operand) {
    return false;
  }

  // Only perform sinks if all operands can easily change shape.
  if (!AllOperandsHaveEasyShapeChanges(instruction, first_reshape_operand)) {
    return false;
  }

  auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
  // At this point we've decided to sink reshape/transpose operands.
  const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape();
  VLOG(3) << "** Sinking reshape or transpose: "
          << instruction->ToString(print_no_metadata)
          << "\n\tfirst reshape operand: "
          << first_reshape_operand->ToString(print_no_metadata)
          << "\n\tnew operand shape: "
          << ShapeUtil::HumanString(new_operand_shape);

  auto operands = instruction->operands();
  for (size_t i = 0; i < operands.size(); ++i) {
    // All scalar operands remain as-is, even if they're reshapes or
    // transposes, to simplify handling of the special scalar broadcast rules
    // for ops like Select. Scalar reshapes should be cheap anyway.
    if (ShapeUtil::IsScalar(operands[i]->shape())) {
      continue;
    }
    VLOG(3) << "Updating operand #" << i << ": "
            << operands[i]->ToString(print_no_metadata);
    operands[i] = UpdateOperand(computation, first_reshape_operand,
                                new_operand_shape, operands[i]);
  }
  if (HloOpcode::kFusion == instruction->opcode()) {
    // We already know `instruction` is elementwise, and no operand is
    // implicitly broadcast (otherwise AllOperandsHaveEasyShapeChanges would
    // have returned false), so all the fused instructions have the same
    // dimensions.
    for (const auto& fused_instruction : instruction->fused_instructions()) {
      Shape* shape = fused_instruction->mutable_shape();
      *shape->mutable_dimensions() = new_operand_shape.dimensions();
      *shape->mutable_layout() = new_operand_shape.layout();
    }
  }
  HloInstruction* new_elementwise =
      computation->AddInstruction(instruction->CloneWithNewOperands(
          // `instruction` may change the element type, e.g., from
          //   operands[0] -> reshape -> convert (`instruction`)
          // to
          //   operands[0] -> convert' -> reshape'
          //
          // In this case, convert' should have the same element type as
          // `convert` and the same dimensions as operands[0].
          ShapeUtil::ChangeElementType(new_operand_shape,
                                       instruction->shape().element_type()),
          operands));

  std::unique_ptr<HloInstruction> new_reshape;
  switch (first_reshape_operand->opcode()) {
    case HloOpcode::kReshape:
      VLOG(3) << "Creating new reshape for new elementwise op: "
              << new_elementwise->ToString(print_no_metadata);
      new_reshape =
          HloInstruction::CreateReshape(instruction->shape(), new_elementwise);
      break;
    case HloOpcode::kTranspose:
      new_reshape =
          HloInstruction::CreateTranspose(instruction->shape(), new_elementwise,
                                          first_reshape_operand->dimensions());
      break;
    default:
      LOG(FATAL) << "Bad opcode";
  }
  TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
      instruction, std::move(new_reshape)));
  return true;
}

}  // namespace

StatusOr<bool> ReshapeMover::Run(HloModule* module) {
  bool changed = false;
  VLOG(2) << "Pre ReshapeMover HLO:";
  XLA_VLOG_LINES(2, module->ToString());
  for (auto* comp : module->MakeNonfusionComputations()) {
    for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(bool did_change,
                          TrySinkReshapeOrTranspose(comp, instruction));
      changed |= did_change;
    }
  }
  VLOG(2) << "Post ReshapeMover HLO:";
  XLA_VLOG_LINES(2, module->ToString());
  return changed;
}
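
// A minimal usage sketch (the pass is normally scheduled by a backend's
// compiler pipeline; the names below follow the standard XLA pass driver):
//
//   HloPassPipeline pipeline("reshape-mover-example");
//   pipeline.AddPass<ReshapeMover>();
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));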

}  // namespace xla