Home | History | Annotate | Download | only in gpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/gpu/while_transformer.h"
     17 
     18 #include <unordered_map>
     19 #include <vector>
     20 
     21 #include "tensorflow/compiler/xla/literal_util.h"
     22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     23 #include "tensorflow/compiler/xla/shape_util.h"
     24 #include "tensorflow/compiler/xla/status_macros.h"
     25 #include "tensorflow/compiler/xla/util.h"
     26 #include "tensorflow/core/lib/core/errors.h"
     27 
     28 namespace xla {
     29 namespace gpu {
     30 
     31 namespace {
     32 
     33 // TODO(b/33483676) Use an expression tree to specify computations to pattern
     34 // match for while transformations.
     35 
     36 // ExprTree is a simple recursive data structure used to express computation
     37 // patterns to match.
     38 //
     39 // Each ExprTree node is comprised of an HloOpcode, and a set of operands (each
     40 // of type ExprTree). Operands can be added by specifying the index and
     41 // HloOpcode of the operand.
     42 //
     43 // For example, the following computation:
     44 //
     45 //            Parameter
     46 //               |
     47 //   Const  GetTupleElement
     48 //      \   /
     49 //       Add (root)
     50 //
     51 // Can be matched with the following expression tree:
     52 //
     53 //   ExprTree add(HloOpcode::kAdd,
     54 //                ExprTree(HloOpcode::kConstant),
     55 //                ExprTree(HloOpcode::kGetTupleElement,
     56 //                         tuple_index, ExprTree(HloOpcode::kParameter)));
     57 //
     58 // Match the ExprTree root against an Hlo graph:
     59 //
     60 //   ExprTree::TaggedInstructionMap tagged_instructions;
     61 //   TF_RETURN_IF_ERROR(add.Match(computation_->root_instruction(),
     62 //                                &tagged_instructions));
     63 //
     64 // Instructions that are "tagged" with a context-specific string will
     65 // be returned in 'tagged_instructions' for further processing (i.e. parsing
     66 // constants or recording the tuple_index).
     67 //
     68 class ExprTree {
     69  public:
     70   explicit ExprTree(HloOpcode opcode) : opcode_(opcode) {}
     71   ExprTree(HloOpcode opcode, const string& tag) : opcode_(opcode), tag_(tag) {}
     72   ExprTree(HloOpcode opcode, const ExprTree& operand0) : opcode_(opcode) {
     73     SetOperand(0, operand0);
     74   }
     75   ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0)
     76       : opcode_(opcode) {
     77     SetOperand(index0, operand0);
     78   }
     79   ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0,
     80            int64 index1, const ExprTree& operand1)
     81       : opcode_(opcode) {
     82     SetOperand(index0, operand0);
     83     SetOperand(index1, operand1);
     84   }
     85   ExprTree(HloOpcode opcode, const string& tag, const ExprTree& operand0)
     86       : opcode_(opcode), tag_(tag) {
     87     SetOperand(0, operand0);
     88   }
     89   ExprTree(HloOpcode opcode, const ExprTree& operand0, const ExprTree& operand1)
     90       : opcode_(opcode) {
     91     SetOperand(0, operand0);
     92     SetOperand(1, operand1);
     93   }
     94 
     95   ExprTree(const ExprTree& to_copy) {
     96     opcode_ = to_copy.opcode_;
     97     tag_ = to_copy.tag_;
     98     if (to_copy.fused_root_tree_ != nullptr) {
     99       fused_root_tree_.reset(new ExprTree(*to_copy.fused_root_tree_));
    100     }
    101     for (auto& pair : to_copy.operands_) {
    102       CHECK(operands_.find(pair.first) == operands_.end());
    103       operands_.insert(std::make_pair(
    104           pair.first, std::unique_ptr<ExprTree>(new ExprTree(*pair.second))));
    105     }
    106   }
    107 
    108   void SetFusedRoot(const ExprTree& fused_root) {
    109     fused_root_tree_.reset(new ExprTree(fused_root));
    110   }
    111 
    112   typedef std::unordered_map<string, const HloInstruction*>
    113       TaggedInstructionMap;
    114 
    115   // Matches 'instruction' HloOpcode against 'opcode_'.
    116   // Recursively matches each operand in 'operands_'.
    117   // Recursively matches fused instructions starting at 'fused_root_tree_'
    118   // if 'opcode_ == kFusion'.
    119   // Returns OK status, and instructions in 'tagged_instructions' for each
    120   // matched ExprTree node with a non-empty 'tag_'.
    121   // Returns error message on failure.
    122   Status Match(const HloInstruction* instruction,
    123                TaggedInstructionMap* tagged_instructions) const {
    124     if (opcode_ != instruction->opcode()) {
    125       return InvalidArgument("got opcode %s, want %s",
    126                              HloOpcodeString(instruction->opcode()).c_str(),
    127                              HloOpcodeString(opcode_).c_str());
    128     }
    129 
    130     VLOG(2) << "Matched " << HloOpcodeString(opcode_) << ": " << tag_;
    131     if (!tag_.empty()) {
    132       tagged_instructions->insert({tag_, instruction});
    133     }
    134 
    135     if (instruction->opcode() == HloOpcode::kFusion) {
    136       CHECK(fused_root_tree_ != nullptr);
    137       // Match fused instructions for this node starting a 'fused_root_tree'.
    138       TF_RETURN_IF_ERROR(fused_root_tree_->Match(
    139           instruction->fused_expression_root(), tagged_instructions));
    140     }
    141 
    142     // Match each operand in 'operands_'.
    143     for (auto& pair : operands_) {
    144       TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first),
    145                                             tagged_instructions));
    146     }
    147     return tensorflow::Status::OK();
    148   }
    149 
    150  private:
    151   void SetOperand(int64 index, const ExprTree& operand) {
    152     CHECK_EQ(0, operands_.count(index));
    153     operands_.insert(std::make_pair(index, MakeUnique<ExprTree>(operand)));
    154   }
    155 
    156   HloOpcode opcode_;
    157   std::unordered_map<int64, std::unique_ptr<ExprTree>> operands_;
    158   std::unique_ptr<ExprTree> fused_root_tree_;
    159   string tag_;
    160 };
    161 
    162 // MatcherBase is a base class that provides common functionality for
    163 // sub-classes which match specific target sub-computations (i.e. loop
    164 // induction variable initialization, comparison and update).
    165 class MatcherBase {
    166  public:
    167   MatcherBase() {}
    168   virtual ~MatcherBase() {}
    169 
    170   // Attempts to match each ExprTree in 'expr_trees_'.
    171   // Returns OK on the first successful match, error status otherwise.
    172   virtual tensorflow::Status Run() {
    173     Status status;
    174     for (const ExprTree& expr_tree : expr_trees_) {
    175       status = MatchExprTree(expr_tree);
    176       if (status.ok()) {
    177         return status;
    178       }
    179     }
    180     return status;
    181   }
    182 
    183   virtual Status MatchExprTree(const ExprTree& expr_tree) = 0;
    184 
    185   // Returns the constant value parsed form kConstant 'instruction'.
    186   // Returns error status otherwise.
    187   Status ParseConstInteger(const HloInstruction* instruction,
    188                            int64* const_value) const {
    189     CHECK_EQ(HloOpcode::kConstant, instruction->opcode());
    190     PrimitiveType element_type = instruction->shape().element_type();
    191     if (element_type != S32 && element_type != S64) {
    192       return InvalidArgument("Expected constant of integral type.");
    193     }
    194     const Literal& literal = instruction->literal();
    195     PrimitiveType type = literal.shape().element_type();
    196     if (type != S32 && type != S64) {
    197       return InvalidArgument("Must use S32 or S64 integral types.");
    198     }
    199     if (type == S32) {
    200       *const_value = static_cast<int64>(literal.GetFirstElement<int32>());
    201     } else if (type == S64) {
    202       *const_value = literal.GetFirstElement<int64>();
    203     }
    204     return tensorflow::Status::OK();
    205   }
    206 
    207   StatusOr<const HloInstruction*> GetTaggedInstruction(
    208       const string& tag,
    209       const ExprTree::TaggedInstructionMap& tagged_instructions) {
    210     auto it = tagged_instructions.find(tag);
    211     if (it == tagged_instructions.end()) {
    212       return InvalidArgument("Cound not find instruction for tag: %s",
    213                              tag.c_str());
    214     }
    215     return it->second;
    216   }
    217 
    218  protected:
    219   std::vector<ExprTree> expr_trees_;
    220 
    221  private:
    222   TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase);
    223 };
    224 
    225 // WhileConditionComputationMatcher attempts to match a target computation
    226 // pattern in the while condition sub-computation.
    227 // If the target pattern is matched, two pieces of information are extracted
    228 // from 'tagged' instructions returned by the matcher:
    229 //
    230 // *) 'tuple_index':
    231 //    *) The loop induction variable tuple_index from the GetTupleElement
    232 //       instruction of the matched computation.
    233 //    *) Used in subsequent matching passes of while init operand and body
    234 //       computations to select loop induction variable tuple element.
    235 //
    236 // *) 'loop_limit':
    237 //    *) The integral value from Constant root operand in matched computation.
    238 //    *) Used as the constant for the loop limit.
    239 //
    240 class WhileConditionComputationMatcher : public MatcherBase {
    241  public:
    242   explicit WhileConditionComputationMatcher(const HloComputation* computation)
    243       : computation_(computation) {
    244     expr_trees_.emplace_back(BuildCondExprTree());
    245   }
    246 
    247   int64 loop_limit() const { return loop_limit_; }
    248   int64 tuple_index() const { return tuple_index_; }
    249 
    250  private:
    251   // Builds expression tree for the following condition computation:
    252   //
    253   //     Const  Parameter
    254   //        \     /
    255   //         Fusion ------------> FusionParam FusionParam
    256   //                                  \          /
    257   //                                  GTE       /
    258   //                                    \      /
    259   //                                    LessThan (fused root)
    260   //
    261   ExprTree BuildCondExprTree() {
    262     // Build ExprTree for fused instructions.
    263     ExprTree fused_root(
    264         HloOpcode::kLt,
    265         ExprTree(HloOpcode::kGetTupleElement, "gte",
    266                  ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")),
    267         ExprTree(HloOpcode::kParameter));
    268 
    269     // Build top-level computation.
    270     ExprTree root(HloOpcode::kFusion,
    271                   ExprTree(HloOpcode::kConstant, "loop_limit"),
    272                   ExprTree(HloOpcode::kParameter, "param0"));
    273 
    274     root.SetFusedRoot(fused_root);
    275     return root;
    276   }
    277 
    278   Status MatchExprTree(const ExprTree& expr_tree) override {
    279     VLOG(2) << "MATCHING while condition";
    280     ExprTree::TaggedInstructionMap tagged_instructions;
    281     TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(),
    282                                        &tagged_instructions));
    283 
    284     // Get tagged GTE instruction and set 'tuple_index_'.
    285     TF_ASSIGN_OR_RETURN(const HloInstruction* gte,
    286                         GetTaggedInstruction("gte", tagged_instructions));
    287     tuple_index_ = gte->tuple_index();
    288 
    289     // Get tagged Constant instruction and parse 'loop_limit_'.
    290     TF_ASSIGN_OR_RETURN(
    291         const HloInstruction* const_hlo,
    292         GetTaggedInstruction("loop_limit", tagged_instructions));
    293     TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_limit_));
    294 
    295     // Get tagged "param0" instruction, and check that it matches
    296     // 'computation_' parameter 0.
    297     TF_ASSIGN_OR_RETURN(const HloInstruction* param0,
    298                         GetTaggedInstruction("param0", tagged_instructions));
    299     if (param0 != computation_->parameter_instruction(0)) {
    300       return InvalidArgument("Unexpected Parameter0 instruction : %s",
    301                              param0->name().c_str());
    302     }
    303 
    304     // Get tagged 'gte.fusion_param.param0', find its associated fusion operand,
    305     // and compare it to 'computation_' parameter0.
    306     TF_ASSIGN_OR_RETURN(
    307         const HloInstruction* gte_fusion_param0,
    308         GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions));
    309     CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode());
    310     CHECK(gte_fusion_param0->IsFused());
    311     if (gte_fusion_param0->parent()->FusionInstruction()->operand(
    312             gte_fusion_param0->parameter_number()) !=
    313         computation_->parameter_instruction(0)) {
    314       return InvalidArgument("Could not match fusion param: %s",
    315                              gte_fusion_param0->name().c_str());
    316     }
    317 
    318     return tensorflow::Status::OK();
    319   }
    320 
    321   const HloComputation* computation_;
    322 
    323   int64 loop_limit_ = -1;
    324   int64 tuple_index_ = -1;
    325 
    326   TF_DISALLOW_COPY_AND_ASSIGN(WhileConditionComputationMatcher);
    327 };
    328 
    329 // WhileInitOperandMatcher matches a target computation pattern of the
    330 // while instructions 'init' operand, indexing the tuple at 'tuple_index'.
    331 // On success, parses constant 'loop_start' which represents the loop induction
    332 // variable start values, then returns OK.
    333 // Returns error status otherwise.
    334 class WhileInitOperandMatcher : public MatcherBase {
    335  public:
    336   WhileInitOperandMatcher(const HloInstruction* while_hlo,
    337                           const int64 tuple_index)
    338       : while_hlo_(while_hlo), tuple_index_(tuple_index) {
    339     expr_trees_.emplace_back(BuildInitExprTree());
    340   }
    341 
    342   int64 loop_start() const { return loop_start_; }
    343 
    344  private:
    345   // Builds expression tree for the following while init operand subcomputation:
    346   //
    347   //             Const
    348   //               |
    349   //             Copy
    350   //               |
    351   //             Tuple0
    352   //               |
    353   //             While
    354   //
    355   ExprTree BuildInitExprTree() {
    356     return ExprTree(
    357         HloOpcode::kWhile, "while",
    358         ExprTree(HloOpcode::kTuple, tuple_index_,
    359                  ExprTree(HloOpcode::kCopy,
    360                           ExprTree(HloOpcode::kConstant, "loop_start"))));
    361   }
    362 
    363   Status MatchExprTree(const ExprTree& expr_tree) override {
    364     VLOG(2) << "MATCHING while init";
    365     ExprTree::TaggedInstructionMap tagged_instructions;
    366     TF_RETURN_IF_ERROR(expr_tree.Match(while_hlo_, &tagged_instructions));
    367 
    368     // Get tagged while instruction check against 'while_hlo_'.
    369     TF_ASSIGN_OR_RETURN(const HloInstruction* while_hlo,
    370                         GetTaggedInstruction("while", tagged_instructions));
    371     if (while_hlo != while_hlo_) {
    372       return InvalidArgument("Expected While for instruction : %s",
    373                              while_hlo->name().c_str());
    374     }
    375 
    376     // Get tagged Constant instruction and parse 'loop_start_'.
    377     TF_ASSIGN_OR_RETURN(
    378         const HloInstruction* const_hlo,
    379         GetTaggedInstruction("loop_start", tagged_instructions));
    380     TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_));
    381 
    382     return tensorflow::Status::OK();
    383   }
    384 
    385   const HloInstruction* while_hlo_;
    386   const int64 tuple_index_;
    387 
    388   int64 loop_start_ = -1;
    389 
    390   TF_DISALLOW_COPY_AND_ASSIGN(WhileInitOperandMatcher);
    391 };
    392 
    393 // WhileBodyComputationMatcher matches a target computation pattern for
    394 // the loop induction variable update. Matching proceeds from the while body
    395 // computation root[tuple_index] to param[tuple_index], where 'tuple_index'
    396 // If the target pattern is matched, parses a constant which represents the
    397 // loop induction variable increment value, then returns status OK.
    398 // Returns error status otherwise.
    399 class WhileBodyComputationMatcher : public MatcherBase {
    400  public:
    401   WhileBodyComputationMatcher(const HloComputation* computation,
    402                               const int64 tuple_index)
    403       : computation_(computation), tuple_index_(tuple_index) {
    404     expr_trees_.emplace_back(BuildBodyExprTree(0, 1));
    405     expr_trees_.emplace_back(BuildBodyExprTree(1, 0));
    406   }
    407 
    408   int64 loop_increment() const { return loop_increment_; }
    409 
    410  private:
    411   // Builds expression tree for the following while body computation:
    412   //
    413   //
    414   //                               FusionParam FusionParam
    415   //                                     \      /
    416   //                  Const Param         \   GTE1
    417   //                     \  /              \  /
    418   //                    Fusion -----------> Add
    419   //                      |
    420   //                     Copy
    421   //                      |
    422   //                     Tuple0
    423   //
    424   ExprTree BuildBodyExprTree(const int64 const_index, const int64 gte_index) {
    425     // Build ExprTree for fused instructions.
    426     ExprTree gte1 =
    427         ExprTree(HloOpcode::kGetTupleElement, "gte",
    428                  ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0"));
    429     ExprTree fused_root(HloOpcode::kAdd, const_index,
    430                         ExprTree(HloOpcode::kParameter), gte_index, gte1);
    431 
    432     // Build fusion instruction (and set fused root).
    433     ExprTree fusion(HloOpcode::kFusion, 0,
    434                     ExprTree(HloOpcode::kConstant, "loop_increment"), 1,
    435                     ExprTree(HloOpcode::kParameter, "param0"));
    436     fusion.SetFusedRoot(fused_root);
    437 
    438     // Build top-level computation.
    439     ExprTree tuple0(HloOpcode::kTuple, tuple_index_,
    440                     ExprTree(HloOpcode::kCopy, fusion));
    441     return tuple0;
    442   }
    443 
    444   Status MatchExprTree(const ExprTree& expr_tree) override {
    445     VLOG(2) << "MATCHING while body";
    446     ExprTree::TaggedInstructionMap tagged_instructions;
    447     TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(),
    448                                        &tagged_instructions));
    449 
    450     for (const auto& pair : tagged_instructions) {
    451       const auto& tag = pair.first;
    452       const auto& inst = pair.second;
    453 
    454       if (tag == "gte" && inst->tuple_index() != tuple_index_) {
    455         // Check that the matched GTE instruction is at the 'tuple_index' we
    456         // matched in the while condition computation.
    457         return InvalidArgument("Unexpected tuple index instruction : %s",
    458                                inst->name().c_str());
    459       } else if (tag == "loop_increment") {
    460         // Parse the constant which represents the loop induction variable
    461         // increment value.
    462         TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_));
    463       } else if (tag == "param0" &&
    464                  inst != computation_->parameter_instruction(0)) {
    465         // Check that the matched parameter == parameter 0 from 'computation_'.
    466         return InvalidArgument("Unexpected Parameter0 instruction : %s",
    467                                inst->name().c_str());
    468       } else if (tag == "gte.fusion_param.param0") {
    469         // Fusion parameter: lookup and compare with associated fusion operand.
    470         CHECK_EQ(HloOpcode::kParameter, inst->opcode());
    471         CHECK(inst->IsFused());
    472         if (inst->parent()->FusionInstruction()->operand(
    473                 inst->parameter_number()) !=
    474             computation_->parameter_instruction(0)) {
    475           return InvalidArgument("Could not match fusion param: %s",
    476                                  inst->name().c_str());
    477         }
    478       }
    479     }
    480     return tensorflow::Status::OK();
    481   }
    482 
    483   const HloComputation* computation_;
    484   const int64 tuple_index_;
    485 
    486   int64 loop_increment_ = -1;
    487 
    488   TF_DISALLOW_COPY_AND_ASSIGN(WhileBodyComputationMatcher);
    489 };
    490 
    491 }  // namespace
    492 
    493 StatusOr<std::tuple<int64, int64, int64>> CanTransformWhileToFor(
    494     const HloInstruction* while_hlo) {
    495   if (while_hlo->opcode() != HloOpcode::kWhile) {
    496     return InvalidArgument("Expected While instruction.");
    497   }
    498 
    499   WhileConditionComputationMatcher cond_matcher(while_hlo->while_condition());
    500   TF_RETURN_IF_ERROR(cond_matcher.Run());
    501 
    502   WhileInitOperandMatcher init_matcher(while_hlo, cond_matcher.tuple_index());
    503   TF_RETURN_IF_ERROR(init_matcher.Run());
    504 
    505   WhileBodyComputationMatcher body_matcher(while_hlo->while_body(),
    506                                            cond_matcher.tuple_index());
    507   TF_RETURN_IF_ERROR(body_matcher.Run());
    508 
    509   // Check for valid For loop parameters.
    510   if (init_matcher.loop_start() >= cond_matcher.loop_limit()) {
    511     return InvalidArgument("Loop start must be less than loop limit.");
    512   }
    513   if (body_matcher.loop_increment() <= 0) {
    514     return InvalidArgument("Loop increment must greater than zero.");
    515   }
    516   return std::make_tuple(init_matcher.loop_start(), cond_matcher.loop_limit(),
    517                          body_matcher.loop_increment());
    518 }
    519 
    520 }  // namespace gpu
    521 }  // namespace xla
    522