Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/liveness_util.h"
     17 
     18 #include <algorithm>
     19 #include <utility>
     20 #include <vector>
     21 
     22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     23 #include "tensorflow/compiler/xla/service/logical_buffer.h"
     24 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
     25 #include "tensorflow/compiler/xla/shape_util.h"
     26 #include "tensorflow/compiler/xla/types.h"
     27 #include "tensorflow/compiler/xla/util.h"
     28 
     29 namespace xla {
     30 
     31 bool DoesNotUseOperandBuffer(const HloInstruction* operand,
     32                              const ShapeIndex& index,
     33                              const HloInstruction* user,
     34                              const TuplePointsToAnalysis& points_to_analysis) {
     35   CHECK(user->IsUserOf(operand))
     36       << "user: " << user->ToString() << " operand: " << operand->ToString();
     37   if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
     38     // GetTupleElement instructions only access the top-level buffer of their
     39     // operand.
     40     return true;
     41   } else if (user->opcode() == HloOpcode::kFusion &&
     42              user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
     43     // Find fusion parameter associated with 'operand'.
     44     auto it = std::find_if(
     45         user->fused_parameters().begin(), user->fused_parameters().end(),
     46         [=](HloInstruction* fused_param) {
     47           return user->operand(fused_param->parameter_number()) == operand;
     48         });
     49     CHECK(it != user->fused_parameters().end());
     50     // Iterate through all users of all buffer aliases of the buffer in the
     51     // points-to set of fusion parameter at 'index'.
     52     // Return false if any uses are detected at 'index', returns true otherwise.
     53     const LogicalBuffer* buffer =
     54         points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie();
     55     for (const BufferAlias& alias :
     56          points_to_analysis.GetBufferAliases(*buffer)) {
     57       for (HloInstruction* alias_user : alias.instruction()->users()) {
     58         if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
     59                                     alias_user, points_to_analysis)) {
     60           continue;
     61         }
     62         // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
     63         return false;
     64       }
     65     }
     66     // Return true: found no uses of 'operand' at 'index' in 'user'.
     67     return true;
     68   }
     69   return false;
     70 }
     71 
     72 bool DoesNotUseOperandBuffer(const HloInstruction* operand,
     73                              const ShapeIndex& index,
     74                              const HloInstruction* user,
     75                              const HloDataflowAnalysis& dataflow) {
     76   CHECK(user->IsUserOf(operand))
     77       << "user: " << user->ToString() << " operand: " << operand->ToString();
     78   if (user->opcode() == HloOpcode::kFusion &&
     79       user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
     80     // Find fusion parameter associated with 'operand'.
     81     HloInstruction* fusion_param =
     82         user->fused_parameter(user->operand_index(operand));
     83     // Iterate through all users of all uses of the fusion parameter value.
     84     // Return false if any uses are detected, returns true otherwise.
     85     const HloValue& value = dataflow.GetValueDefinedAt(fusion_param, index);
     86     return value.uses().empty();
     87   } else {
     88     // Return false if no value at 'operand' and 'index' is used at 'user'.
     89     for (const HloValue* value :
     90          dataflow.GetValueSet(operand, index).values()) {
     91       for (const HloUse& use : value->uses()) {
     92         if (use.instruction == user) {
     93           return false;
     94         }
     95       }
     96     }
     97   }
     98 
     99   return true;
    100 }
    101 
    102 namespace {
    103 
    104 // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
    105 // Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
    106 // where 'user' is a user of an alias of 'instruction' at 'index', and
    107 // 'operand_index' is the operand index at which the alias appears in the
    108 // operand list of 'user'.
    109 std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
    110     HloInstruction* instruction, const ShapeIndex& index,
    111     const TuplePointsToAnalysis& points_to_analysis) {
    112   std::vector<std::pair<HloInstruction*, int64>> uses;
    113   const PointsToSet::BufferList& points_to =
    114       points_to_analysis.GetPointsToSet(instruction).element(index);
    115   for (const LogicalBuffer* buffer : points_to) {
    116     for (const BufferAlias& alias :
    117          points_to_analysis.GetBufferAliases(*buffer)) {
    118       for (HloInstruction* alias_user : alias.instruction()->users()) {
    119         if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
    120                                     alias_user, points_to_analysis)) {
    121           continue;
    122         }
    123         for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
    124           uses.emplace_back(alias_user, op_idx);
    125         }
    126       }
    127     }
    128   }
    129   return uses;
    130 }
    131 
    132 // Returns true if there is exactly one use of 'operand' at 'operand_index'
    133 // in 'fusion.fused_instructions', where the singleton use is the fused
    134 // root at operand index 'use_operand_index'. Returns false otherwise.
    135 //
    136 // REQUIRES: 'fusion' opcode is a kFusion instruction.
    137 bool HasUniqueFusedUseOfOperandAt(
    138     HloInstruction* operand, const ShapeIndex& operand_index,
    139     HloInstruction* fusion, const int64 use_operand_index,
    140     const TuplePointsToAnalysis& points_to_analysis) {
    141   CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
    142   // Check that 'operand' is unique in the operand list of 'fusion'.
    143   if (fusion->OperandIndices(operand).size() > 1) {
    144     return false;
    145   }
    146   // Find fusion parameter associated with 'operand'.
    147   const auto& fused_params = fusion->fused_parameters();
    148   auto fused_param_it = std::find_if(
    149       fused_params.begin(), fused_params.end(),
    150       [&](HloInstruction* fused_param) {
    151         return fusion->operand(fused_param->parameter_number()) == operand;
    152       });
    153   if (fused_param_it == fused_params.end()) {
    154     return false;
    155   }
    156   auto* fused_param = *fused_param_it;
    157   // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'.
    158   auto fused_param_uses = GetAllUsesOfInstructionAtIndex(
    159       fused_param, operand_index, points_to_analysis);
    160   // Return true iff there is exactly one use of 'operand' at 'index', and
    161   // this singleton use is the fused root (at index in 'use_operand_indices').
    162   return fused_param_uses.size() == 1 &&
    163          fused_param_uses[0].first == fusion->fused_expression_root() &&
    164          fused_param_uses[0].second == use_operand_index;
    165 }
    166 
    167 }  // namespace
    168 
    169 // User and operand can share buffers iff both instructions emit the same shape
    170 // and layout, and 'user' meets one of the following qualifications:
    171 //
    172 // (1) Is element-wise. Or...
    173 // (2) Is a loop fusion instruction where the only use of 'operand' at 'index'
    174 //     in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
    175 //     at operand 0. Or...
    176 // (3) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion
    177 //     instruction where the only use of 'operand' at 'index' in the set
    178 //     'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or...
    179 // (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index
    180 //     0.
    181 //
    182 // (2) and (3) can only be determined if points-to analysis is available.
    183 bool CanShareOperandBufferWithUser(
    184     HloInstruction* operand, const ShapeIndex& operand_index,
    185     HloInstruction* user, const ShapeIndex& user_index,
    186     const TuplePointsToAnalysis& points_to_analysis) {
    187   CHECK(user->IsUserOf(operand))
    188       << "user: " << user->ToString() << " operand: " << operand->ToString();
    189   const Shape& operand_subshape =
    190       ShapeUtil::GetSubshape(operand->shape(), operand_index);
    191   const Shape& user_subshape =
    192       ShapeUtil::GetSubshape(user->shape(), user_index);
    193   // Check that operand and user emit the same shape and layout.
    194   if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
    195     return false;
    196   }
    197   if (user->opcode() == HloOpcode::kFusion) {
    198     if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
    199         user->fused_expression_root()->opcode() ==
    200             HloOpcode::kDynamicUpdateSlice) {
    201       // Loop fusion with kDynamicUpdateSlice fused root.
    202       //
    203       // Returns true iff there is exactly one use of 'operand' at shape index
    204       // 'operand_index', and this singleton use is the fused root at operand
    205       // index 0.
    206       return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0,
    207                                           points_to_analysis);
    208     } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
    209                user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
    210       // Output fusion with kAdd fused root.
    211 
    212       // Check if one operand of kAdd fused root is either kDot, or nested
    213       // kFusion of kind kTransposeDot.
    214       auto* add = user->fused_expression_root();
    215       auto add_operand_it =
    216           std::find_if(add->operands().begin(), add->operands().end(),
    217                        [&](HloInstruction* operand) {
    218                          return operand->opcode() == HloOpcode::kConvolution ||
    219                                 operand->opcode() == HloOpcode::kDot ||
    220                                 (operand->opcode() == HloOpcode::kFusion &&
    221                                  operand->fusion_kind() ==
    222                                      HloInstruction::FusionKind::kTransposeDot);
    223                        });
    224       if (add_operand_it == add->operands().end()) {
    225         return false;
    226       }
    227       auto* matched_add_operand = *add_operand_it;
    228       // Calculate operand index of 'add' operand which was not matched above.
    229       const int64 other_add_operand_index =
    230           matched_add_operand == add->operand(0) ? 1 : 0;
    231       // Returns true iff there is exactly one use of 'operand' at shape index
    232       // 'operand_index', and this singleton use is the fused root (at operand
    233       // index 'other_add_operand_index').
    234       return HasUniqueFusedUseOfOperandAt(operand, operand_index, user,
    235                                           other_add_operand_index,
    236                                           points_to_analysis);
    237     }
    238   }
    239   if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
    240       user->opcode() == HloOpcode::kWhile) {
    241     // We eliminated other users in BufferLiveness::live_range_strictly_before,
    242     // so here we just need to check that the use is at operand index 0.
    243     std::vector<int64> operand_indices = user->OperandIndices(operand);
    244     return operand_indices.size() == 1 && operand_indices[0] == 0;
    245   }
    246   if (user->opcode() == HloOpcode::kCall) {
    247     // TODO(b/62548313): Remove when buffer assignment is module scoped and
    248     // does not assign buffers to calls.
    249     // Find called computation parameter associated with 'operand'.
    250     const std::vector<int64> operand_indices = user->OperandIndices(operand);
    251     if (operand_indices.size() > 1) {
    252       return false;
    253     }
    254     CHECK_EQ(1, operand_indices.size());
    255     auto* param = user->to_apply()->parameter_instruction(operand_indices[0]);
    256     // Get all uses of 'operand' at 'index' in called computation.
    257     auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index,
    258                                                      points_to_analysis);
    259 
    260     // Return true iff:
    261     // *) There exists exactly one use of 'operand' in called computation.
    262     // *) The unique use is by the root instruction of called computation.
    263     //    (Note: we check the root of the called computation, because the
    264     //     root result buffer is required to alias with the Call result buffer).
    265     // *) The root instruction of the called computation is element-wise on
    266     //    'operand'.
    267     auto* callee_root = user->to_apply()->root_instruction();
    268     return param_uses.size() == 1 && param_uses[0].first == callee_root &&
    269            callee_root->IsElementwiseOnOperand(param_uses[0].second);
    270   }
    271   // Check if 'user' is element-wise.
    272   return user->IsElementwise();
    273 }
    274 
    275 bool CanShareOperandBufferWithUser(HloInstruction* operand,
    276                                    const ShapeIndex& operand_index,
    277                                    HloInstruction* user,
    278                                    const ShapeIndex& user_index,
    279                                    const HloDataflowAnalysis& dataflow) {
    280   CHECK(user->IsUserOf(operand))
    281       << "user: " << user->ToString() << " operand: " << operand->ToString();
    282   const Shape& operand_subshape =
    283       ShapeUtil::GetSubshape(operand->shape(), operand_index);
    284   const Shape& user_subshape =
    285       ShapeUtil::GetSubshape(user->shape(), user_index);
    286   // Check that operand and user emit the same shape and layout.
    287   if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
    288     return false;
    289   }
    290 
    291   if (user->opcode() == HloOpcode::kFusion) {
    292     // Get the parameter associated with 'operand';
    293     HloInstruction* fusion_param =
    294         user->fused_parameter(user->operand_index(operand));
    295 
    296     const HloValue& value =
    297         dataflow.GetValueDefinedAt(fusion_param, operand_index);
    298     if (value.uses().size() != 1) {
    299       return false;
    300     }
    301     const HloUse& use = value.uses()[0];
    302 
    303     if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
    304         user->fused_expression_root()->opcode() ==
    305             HloOpcode::kDynamicUpdateSlice) {
    306       // Loop fusion with kDynamicUpdateSlice fused root.
    307       //
    308       // Returns true iff there is exactly one use of 'operand' at shape index
    309       // 'operand_index', and this singleton use is the fused root at operand
    310       // index 0.
    311       return use.instruction == user->fused_expression_root() &&
    312              use.operand_number == 0;
    313     } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
    314                user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
    315       // Output fusion with kAdd fused root.
    316 
    317       // Check if one operand of kAdd fused root is either kDot, or nested
    318       // kFusion of kind kTransposeDot.
    319       auto* add = user->fused_expression_root();
    320       auto add_operand_it =
    321           std::find_if(add->operands().begin(), add->operands().end(),
    322                        [&](HloInstruction* operand) {
    323                          return operand->opcode() == HloOpcode::kConvolution ||
    324                                 operand->opcode() == HloOpcode::kDot ||
    325                                 (operand->opcode() == HloOpcode::kFusion &&
    326                                  operand->fusion_kind() ==
    327                                      HloInstruction::FusionKind::kTransposeDot);
    328                        });
    329       if (add_operand_it == add->operands().end()) {
    330         return false;
    331       }
    332       auto* matched_add_operand = *add_operand_it;
    333       // Calculate operand index of 'add' operand which was not matched above.
    334       const int64 other_add_operand_index =
    335           matched_add_operand == add->operand(0) ? 1 : 0;
    336       // Returns true iff there is exactly one use of 'operand' at shape index
    337       // 'operand_index', and this singleton use is the fused root (at operand
    338       // index 'other_add_operand_index').
    339       return use.instruction == user->fused_expression_root() &&
    340              use.operand_number == other_add_operand_index;
    341     }
    342   }
    343   if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
    344       user->opcode() == HloOpcode::kWhile) {
    345     // We eliminated other users in BufferLiveness::live_range_strictly_before,
    346     // so here we just need to check that the use is at operand index 0.
    347     std::vector<int64> operand_indices = user->OperandIndices(operand);
    348     return operand_indices.size() == 1 && operand_indices[0] == 0;
    349   }
    350   if (user->opcode() == HloOpcode::kCall) {
    351     // Get all uses of value defined by 'operand' at 'operand_index'.
    352     const auto& uses =
    353         dataflow.GetValueDefinedAt(operand, operand_index).uses();
    354     // Return true iff:
    355     // *) There exists two uses of 'operand'.
    356     // *) One use is by 'user' (caller).
    357     // *) One use is by root instruction of called computation (callee root).
    358     //    (Note: we check the root of the called computation, because the
    359     //     root result buffer is required to alias with the Call result buffer).
    360     // *) The root instruction of the called computation is element-wise on
    361     //    'operand'.
    362     const bool found_caller_use =
    363         std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) {
    364           return use.instruction == user;
    365         }) != uses.end();
    366     auto* callee_root = user->to_apply()->root_instruction();
    367     const bool found_elementwise_callee_use =
    368         std::find_if(
    369             uses.begin(), uses.end(), [callee_root](const HloUse& use) {
    370               return use.instruction == callee_root &&
    371                      callee_root->IsElementwiseOnOperand(use.operand_number);
    372             }) != uses.end();
    373     return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
    374   }
    375   // Check if 'user' is element-wise.
    376   return user->IsElementwise();
    377 }
    378 
    379 }  // namespace xla
    380