/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/while_transformer.h"

#include <unordered_map>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace gpu {

namespace {

// TODO(b/33483676) Use an expression tree to specify computations to pattern
// match for while transformations.

// ExprTree is a simple recursive data structure used to express computation
// patterns to match.
//
// Each ExprTree node is comprised of an HloOpcode, and a set of operands (each
// of type ExprTree). Operands can be added by specifying the index and
// HloOpcode of the operand.
42 // 43 // For example, the following computation: 44 // 45 // Parameter 46 // | 47 // Const GetTupleElement 48 // \ / 49 // Add (root) 50 // 51 // Can be matched with the following expression tree: 52 // 53 // ExprTree add(HloOpcode::kAdd, 54 // ExprTree(HloOpcode::kConstant), 55 // ExprTree(HloOpcode::kGetTupleElement, 56 // tuple_index, ExprTree(HloOpcode::kParameter))); 57 // 58 // Match the ExprTree root against an Hlo graph: 59 // 60 // ExprTree::TaggedInstructionMap tagged_instructions; 61 // TF_RETURN_IF_ERROR(add.Match(computation_->root_instruction(), 62 // &tagged_instructions)); 63 // 64 // Instructions that are "tagged" with a context-specific string will 65 // be returned in 'tagged_instructions' for further processing (i.e. parsing 66 // constants or recording the tuple_index). 67 // 68 class ExprTree { 69 public: 70 explicit ExprTree(HloOpcode opcode) : opcode_(opcode) {} 71 ExprTree(HloOpcode opcode, const string& tag) : opcode_(opcode), tag_(tag) {} 72 ExprTree(HloOpcode opcode, const ExprTree& operand0) : opcode_(opcode) { 73 SetOperand(0, operand0); 74 } 75 ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0) 76 : opcode_(opcode) { 77 SetOperand(index0, operand0); 78 } 79 ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0, 80 int64 index1, const ExprTree& operand1) 81 : opcode_(opcode) { 82 SetOperand(index0, operand0); 83 SetOperand(index1, operand1); 84 } 85 ExprTree(HloOpcode opcode, const string& tag, const ExprTree& operand0) 86 : opcode_(opcode), tag_(tag) { 87 SetOperand(0, operand0); 88 } 89 ExprTree(HloOpcode opcode, const ExprTree& operand0, const ExprTree& operand1) 90 : opcode_(opcode) { 91 SetOperand(0, operand0); 92 SetOperand(1, operand1); 93 } 94 95 ExprTree(const ExprTree& to_copy) { 96 opcode_ = to_copy.opcode_; 97 tag_ = to_copy.tag_; 98 if (to_copy.fused_root_tree_ != nullptr) { 99 fused_root_tree_.reset(new ExprTree(*to_copy.fused_root_tree_)); 100 } 101 for (auto& pair : to_copy.operands_) { 
102 CHECK(operands_.find(pair.first) == operands_.end()); 103 operands_.insert(std::make_pair( 104 pair.first, std::unique_ptr<ExprTree>(new ExprTree(*pair.second)))); 105 } 106 } 107 108 void SetFusedRoot(const ExprTree& fused_root) { 109 fused_root_tree_.reset(new ExprTree(fused_root)); 110 } 111 112 typedef std::unordered_map<string, const HloInstruction*> 113 TaggedInstructionMap; 114 115 // Matches 'instruction' HloOpcode against 'opcode_'. 116 // Recursively matches each operand in 'operands_'. 117 // Recursively matches fused instructions starting at 'fused_root_tree_' 118 // if 'opcode_ == kFusion'. 119 // Returns OK status, and instructions in 'tagged_instructions' for each 120 // matched ExprTree node with a non-empty 'tag_'. 121 // Returns error message on failure. 122 Status Match(const HloInstruction* instruction, 123 TaggedInstructionMap* tagged_instructions) const { 124 if (opcode_ != instruction->opcode()) { 125 return InvalidArgument("got opcode %s, want %s", 126 HloOpcodeString(instruction->opcode()).c_str(), 127 HloOpcodeString(opcode_).c_str()); 128 } 129 130 VLOG(2) << "Matched " << HloOpcodeString(opcode_) << ": " << tag_; 131 if (!tag_.empty()) { 132 tagged_instructions->insert({tag_, instruction}); 133 } 134 135 if (instruction->opcode() == HloOpcode::kFusion) { 136 CHECK(fused_root_tree_ != nullptr); 137 // Match fused instructions for this node starting a 'fused_root_tree'. 138 TF_RETURN_IF_ERROR(fused_root_tree_->Match( 139 instruction->fused_expression_root(), tagged_instructions)); 140 } 141 142 // Match each operand in 'operands_'. 
143 for (auto& pair : operands_) { 144 TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first), 145 tagged_instructions)); 146 } 147 return tensorflow::Status::OK(); 148 } 149 150 private: 151 void SetOperand(int64 index, const ExprTree& operand) { 152 CHECK_EQ(0, operands_.count(index)); 153 operands_.insert(std::make_pair(index, MakeUnique<ExprTree>(operand))); 154 } 155 156 HloOpcode opcode_; 157 std::unordered_map<int64, std::unique_ptr<ExprTree>> operands_; 158 std::unique_ptr<ExprTree> fused_root_tree_; 159 string tag_; 160 }; 161 162 // MatcherBase is a base class that provides common functionality for 163 // sub-classes which match specific target sub-computations (i.e. loop 164 // induction variable initialization, comparison and update). 165 class MatcherBase { 166 public: 167 MatcherBase() {} 168 virtual ~MatcherBase() {} 169 170 // Attempts to match each ExprTree in 'expr_trees_'. 171 // Returns OK on the first successful match, error status otherwise. 172 virtual tensorflow::Status Run() { 173 Status status; 174 for (const ExprTree& expr_tree : expr_trees_) { 175 status = MatchExprTree(expr_tree); 176 if (status.ok()) { 177 return status; 178 } 179 } 180 return status; 181 } 182 183 virtual Status MatchExprTree(const ExprTree& expr_tree) = 0; 184 185 // Returns the constant value parsed form kConstant 'instruction'. 186 // Returns error status otherwise. 
187 Status ParseConstInteger(const HloInstruction* instruction, 188 int64* const_value) const { 189 CHECK_EQ(HloOpcode::kConstant, instruction->opcode()); 190 PrimitiveType element_type = instruction->shape().element_type(); 191 if (element_type != S32 && element_type != S64) { 192 return InvalidArgument("Expected constant of integral type."); 193 } 194 const Literal& literal = instruction->literal(); 195 PrimitiveType type = literal.shape().element_type(); 196 if (type != S32 && type != S64) { 197 return InvalidArgument("Must use S32 or S64 integral types."); 198 } 199 if (type == S32) { 200 *const_value = static_cast<int64>(literal.GetFirstElement<int32>()); 201 } else if (type == S64) { 202 *const_value = literal.GetFirstElement<int64>(); 203 } 204 return tensorflow::Status::OK(); 205 } 206 207 StatusOr<const HloInstruction*> GetTaggedInstruction( 208 const string& tag, 209 const ExprTree::TaggedInstructionMap& tagged_instructions) { 210 auto it = tagged_instructions.find(tag); 211 if (it == tagged_instructions.end()) { 212 return InvalidArgument("Cound not find instruction for tag: %s", 213 tag.c_str()); 214 } 215 return it->second; 216 } 217 218 protected: 219 std::vector<ExprTree> expr_trees_; 220 221 private: 222 TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase); 223 }; 224 225 // WhileConditionComputationMatcher attempts to match a target computation 226 // pattern in the while condition sub-computation. 227 // If the target pattern is matched, two pieces of information are extracted 228 // from 'tagged' instructions returned by the matcher: 229 // 230 // *) 'tuple_index': 231 // *) The loop induction variable tuple_index from the GetTupleElement 232 // instruction of the matched computation. 233 // *) Used in subsequent matching passes of while init operand and body 234 // computations to select loop induction variable tuple element. 235 // 236 // *) 'loop_limit': 237 // *) The integral value from Constant root operand in matched computation. 
238 // *) Used as the constant for the loop limit. 239 // 240 class WhileConditionComputationMatcher : public MatcherBase { 241 public: 242 explicit WhileConditionComputationMatcher(const HloComputation* computation) 243 : computation_(computation) { 244 expr_trees_.emplace_back(BuildCondExprTree()); 245 } 246 247 int64 loop_limit() const { return loop_limit_; } 248 int64 tuple_index() const { return tuple_index_; } 249 250 private: 251 // Builds expression tree for the following condition computation: 252 // 253 // Const Parameter 254 // \ / 255 // Fusion ------------> FusionParam FusionParam 256 // \ / 257 // GTE / 258 // \ / 259 // LessThan (fused root) 260 // 261 ExprTree BuildCondExprTree() { 262 // Build ExprTree for fused instructions. 263 ExprTree fused_root( 264 HloOpcode::kLt, 265 ExprTree(HloOpcode::kGetTupleElement, "gte", 266 ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")), 267 ExprTree(HloOpcode::kParameter)); 268 269 // Build top-level computation. 270 ExprTree root(HloOpcode::kFusion, 271 ExprTree(HloOpcode::kConstant, "loop_limit"), 272 ExprTree(HloOpcode::kParameter, "param0")); 273 274 root.SetFusedRoot(fused_root); 275 return root; 276 } 277 278 Status MatchExprTree(const ExprTree& expr_tree) override { 279 VLOG(2) << "MATCHING while condition"; 280 ExprTree::TaggedInstructionMap tagged_instructions; 281 TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), 282 &tagged_instructions)); 283 284 // Get tagged GTE instruction and set 'tuple_index_'. 285 TF_ASSIGN_OR_RETURN(const HloInstruction* gte, 286 GetTaggedInstruction("gte", tagged_instructions)); 287 tuple_index_ = gte->tuple_index(); 288 289 // Get tagged Constant instruction and parse 'loop_limit_'. 
290 TF_ASSIGN_OR_RETURN( 291 const HloInstruction* const_hlo, 292 GetTaggedInstruction("loop_limit", tagged_instructions)); 293 TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_limit_)); 294 295 // Get tagged "param0" instruction, and check that it matches 296 // 'computation_' parameter 0. 297 TF_ASSIGN_OR_RETURN(const HloInstruction* param0, 298 GetTaggedInstruction("param0", tagged_instructions)); 299 if (param0 != computation_->parameter_instruction(0)) { 300 return InvalidArgument("Unexpected Parameter0 instruction : %s", 301 param0->name().c_str()); 302 } 303 304 // Get tagged 'gte.fusion_param.param0', find its associated fusion operand, 305 // and compare it to 'computation_' parameter0. 306 TF_ASSIGN_OR_RETURN( 307 const HloInstruction* gte_fusion_param0, 308 GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions)); 309 CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode()); 310 CHECK(gte_fusion_param0->IsFused()); 311 if (gte_fusion_param0->parent()->FusionInstruction()->operand( 312 gte_fusion_param0->parameter_number()) != 313 computation_->parameter_instruction(0)) { 314 return InvalidArgument("Could not match fusion param: %s", 315 gte_fusion_param0->name().c_str()); 316 } 317 318 return tensorflow::Status::OK(); 319 } 320 321 const HloComputation* computation_; 322 323 int64 loop_limit_ = -1; 324 int64 tuple_index_ = -1; 325 326 TF_DISALLOW_COPY_AND_ASSIGN(WhileConditionComputationMatcher); 327 }; 328 329 // WhileInitOperandMatcher matches a target computation pattern of the 330 // while instructions 'init' operand, indexing the tuple at 'tuple_index'. 331 // On success, parses constant 'loop_start' which represents the loop induction 332 // variable start values, then returns OK. 333 // Returns error status otherwise. 
334 class WhileInitOperandMatcher : public MatcherBase { 335 public: 336 WhileInitOperandMatcher(const HloInstruction* while_hlo, 337 const int64 tuple_index) 338 : while_hlo_(while_hlo), tuple_index_(tuple_index) { 339 expr_trees_.emplace_back(BuildInitExprTree()); 340 } 341 342 int64 loop_start() const { return loop_start_; } 343 344 private: 345 // Builds expression tree for the following while init operand subcomputation: 346 // 347 // Const 348 // | 349 // Copy 350 // | 351 // Tuple0 352 // | 353 // While 354 // 355 ExprTree BuildInitExprTree() { 356 return ExprTree( 357 HloOpcode::kWhile, "while", 358 ExprTree(HloOpcode::kTuple, tuple_index_, 359 ExprTree(HloOpcode::kCopy, 360 ExprTree(HloOpcode::kConstant, "loop_start")))); 361 } 362 363 Status MatchExprTree(const ExprTree& expr_tree) override { 364 VLOG(2) << "MATCHING while init"; 365 ExprTree::TaggedInstructionMap tagged_instructions; 366 TF_RETURN_IF_ERROR(expr_tree.Match(while_hlo_, &tagged_instructions)); 367 368 // Get tagged while instruction check against 'while_hlo_'. 369 TF_ASSIGN_OR_RETURN(const HloInstruction* while_hlo, 370 GetTaggedInstruction("while", tagged_instructions)); 371 if (while_hlo != while_hlo_) { 372 return InvalidArgument("Expected While for instruction : %s", 373 while_hlo->name().c_str()); 374 } 375 376 // Get tagged Constant instruction and parse 'loop_start_'. 377 TF_ASSIGN_OR_RETURN( 378 const HloInstruction* const_hlo, 379 GetTaggedInstruction("loop_start", tagged_instructions)); 380 TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_)); 381 382 return tensorflow::Status::OK(); 383 } 384 385 const HloInstruction* while_hlo_; 386 const int64 tuple_index_; 387 388 int64 loop_start_ = -1; 389 390 TF_DISALLOW_COPY_AND_ASSIGN(WhileInitOperandMatcher); 391 }; 392 393 // WhileBodyComputationMatcher matches a target computation pattern for 394 // the loop induction variable update. 
Matching proceeds from the while body 395 // computation root[tuple_index] to param[tuple_index], where 'tuple_index' 396 // If the target pattern is matched, parses a constant which represents the 397 // loop induction variable increment value, then returns status OK. 398 // Returns error status otherwise. 399 class WhileBodyComputationMatcher : public MatcherBase { 400 public: 401 WhileBodyComputationMatcher(const HloComputation* computation, 402 const int64 tuple_index) 403 : computation_(computation), tuple_index_(tuple_index) { 404 expr_trees_.emplace_back(BuildBodyExprTree(0, 1)); 405 expr_trees_.emplace_back(BuildBodyExprTree(1, 0)); 406 } 407 408 int64 loop_increment() const { return loop_increment_; } 409 410 private: 411 // Builds expression tree for the following while body computation: 412 // 413 // 414 // FusionParam FusionParam 415 // \ / 416 // Const Param \ GTE1 417 // \ / \ / 418 // Fusion -----------> Add 419 // | 420 // Copy 421 // | 422 // Tuple0 423 // 424 ExprTree BuildBodyExprTree(const int64 const_index, const int64 gte_index) { 425 // Build ExprTree for fused instructions. 426 ExprTree gte1 = 427 ExprTree(HloOpcode::kGetTupleElement, "gte", 428 ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")); 429 ExprTree fused_root(HloOpcode::kAdd, const_index, 430 ExprTree(HloOpcode::kParameter), gte_index, gte1); 431 432 // Build fusion instruction (and set fused root). 433 ExprTree fusion(HloOpcode::kFusion, 0, 434 ExprTree(HloOpcode::kConstant, "loop_increment"), 1, 435 ExprTree(HloOpcode::kParameter, "param0")); 436 fusion.SetFusedRoot(fused_root); 437 438 // Build top-level computation. 
439 ExprTree tuple0(HloOpcode::kTuple, tuple_index_, 440 ExprTree(HloOpcode::kCopy, fusion)); 441 return tuple0; 442 } 443 444 Status MatchExprTree(const ExprTree& expr_tree) override { 445 VLOG(2) << "MATCHING while body"; 446 ExprTree::TaggedInstructionMap tagged_instructions; 447 TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), 448 &tagged_instructions)); 449 450 for (const auto& pair : tagged_instructions) { 451 const auto& tag = pair.first; 452 const auto& inst = pair.second; 453 454 if (tag == "gte" && inst->tuple_index() != tuple_index_) { 455 // Check that the matched GTE instruction is at the 'tuple_index' we 456 // matched in the while condition computation. 457 return InvalidArgument("Unexpected tuple index instruction : %s", 458 inst->name().c_str()); 459 } else if (tag == "loop_increment") { 460 // Parse the constant which represents the loop induction variable 461 // increment value. 462 TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_)); 463 } else if (tag == "param0" && 464 inst != computation_->parameter_instruction(0)) { 465 // Check that the matched parameter == parameter 0 from 'computation_'. 466 return InvalidArgument("Unexpected Parameter0 instruction : %s", 467 inst->name().c_str()); 468 } else if (tag == "gte.fusion_param.param0") { 469 // Fusion parameter: lookup and compare with associated fusion operand. 
470 CHECK_EQ(HloOpcode::kParameter, inst->opcode()); 471 CHECK(inst->IsFused()); 472 if (inst->parent()->FusionInstruction()->operand( 473 inst->parameter_number()) != 474 computation_->parameter_instruction(0)) { 475 return InvalidArgument("Could not match fusion param: %s", 476 inst->name().c_str()); 477 } 478 } 479 } 480 return tensorflow::Status::OK(); 481 } 482 483 const HloComputation* computation_; 484 const int64 tuple_index_; 485 486 int64 loop_increment_ = -1; 487 488 TF_DISALLOW_COPY_AND_ASSIGN(WhileBodyComputationMatcher); 489 }; 490 491 } // namespace 492 493 StatusOr<std::tuple<int64, int64, int64>> CanTransformWhileToFor( 494 const HloInstruction* while_hlo) { 495 if (while_hlo->opcode() != HloOpcode::kWhile) { 496 return InvalidArgument("Expected While instruction."); 497 } 498 499 WhileConditionComputationMatcher cond_matcher(while_hlo->while_condition()); 500 TF_RETURN_IF_ERROR(cond_matcher.Run()); 501 502 WhileInitOperandMatcher init_matcher(while_hlo, cond_matcher.tuple_index()); 503 TF_RETURN_IF_ERROR(init_matcher.Run()); 504 505 WhileBodyComputationMatcher body_matcher(while_hlo->while_body(), 506 cond_matcher.tuple_index()); 507 TF_RETURN_IF_ERROR(body_matcher.Run()); 508 509 // Check for valid For loop parameters. 510 if (init_matcher.loop_start() >= cond_matcher.loop_limit()) { 511 return InvalidArgument("Loop start must be less than loop limit."); 512 } 513 if (body_matcher.loop_increment() <= 0) { 514 return InvalidArgument("Loop increment must greater than zero."); 515 } 516 return std::make_tuple(init_matcher.loop_start(), cond_matcher.loop_limit(), 517 body_matcher.loop_increment()); 518 } 519 520 } // namespace gpu 521 } // namespace xla 522