1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ 18 19 #include "absl/strings/str_replace.h" 20 #include "absl/strings/string_view.h" 21 #include "absl/utility/utility.h" 22 #include "tensorflow/compiler/xla/layout_util.h" 23 #include "tensorflow/compiler/xla/literal_util.h" 24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" 25 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 26 #include "tensorflow/compiler/xla/service/hlo_instructions.h" 27 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 28 #include "tensorflow/compiler/xla/shape_util.h" 29 30 namespace xla { 31 32 // A pattern matcher for HloInstructions, Shapes, and Layouts. 33 // 34 // The Match function's first argument must be HloInstruction*, Shape*, or 35 // Layout*. The second argument is a pattern that will be matched against the 36 // first argument, as described below. 37 // 38 // Patterns are constructed using the match::Op, match::Shape, or match::Layout 39 // functions. By default, the returned patterns will match any HloInstruction, 40 // Shape, or Layout, respectively. 
// However the match can be made more specific
// by using the pattern's modifier methods, for example:
//
//   match::Op().WithOpcode(HloOpcode::kAdd).WithOperand(
//     0, match::Op().WithOpcode(HloOpcode::kConstant))
//
// This pattern will match Add instructions whose first operand is a constant.
//
// Each pattern type has the following modifiers, which are described where
// nontrivial.
//
//   Op():
//     - Is: is the given HloInstruction* (i.e. pointer equality)
//     - WithName
//     - WithOpcode
//     - WithoutOpcode: anything other than the given opcode
//     - WithShape: instr's shape matches the given pattern
//     - WithShapeEqualTo: instr's shape is equal to the given Shape
//     - WithShapeCompatibleTo: instr's shape is compatible with the given Shape
//     - WithNumOperands
//     - WithOperand: operand at the given index matches the given pattern
//     - IsConstant
//     - IsNonConstant
//     - IsConstantScalar/IsEffectiveConstantScalar: Optionally accepts a value,
//       e.g. IsConstantScalar() or IsConstantScalar(42).
//     - WithFusionKind
//     - WithTupleIndex: get-tuple-element operations with the given tuple index
//     - WithOneUse: Instruction is used as an operand exactly once.
//     - WithOneUser: Instruction is used by exactly one other instruction, but
//       is possibly used more than once as an operand (e.g. multiply(x,x)).
//     - WithComparisonDirection: instr has the given direction
//
//   Shape():
//     - EqualTo
//     - CompatibleTo
//     - IsScalar/IsEffectiveScalar/IsArray/IsTuple
//     - IsDenseArray/IsSparseArray
//     - WithLayout: shape's layout matches the given pattern (e.g.
//       Layout().WithDenseFormat())
//     - WithLayoutEqualTo: shape's layout equals the argument (i.e. another
//       Layout, but not the result of Layout().foo())
//     - WithSubshape: shape is a tuple whose subshape matches the given pattern
//       (e.g. Shape().IsScalar()).
83 // - WithSubshapeEqualTo: shape is a tuple with a subshape equal to the arg 84 // (i.e. another Shape, but not the result of Shape().foo()) 85 // - WithElementType: shape is an array/scalar with the given elem type 86 // - WithRank: shape is an array/scalar with the given rank 87 // 88 // Layout(): 89 // - EqualTo 90 // - WithDenseFormat/WithSparseFormat 91 // 92 // Op(), Shape(), and Layout() may be passed an argument of type 93 // HloInstruction**, Shape**, or Layout**, respectively, or const versions of 94 // these pointers. If the pattern is matched, the address of the matched value 95 // will be "captured" and stored at this location. 96 // 97 // For example: 98 // HloInstruction* foo = ...; 99 // HloInstruction* matched_operand; 100 // CHECK(Match(foo, 101 // match::Op().WithOperand(0, match::Op(&matched_operand)))); 102 // 103 // Helpers are provided for most HLO instructions. These helpers can be called 104 // with no arguments, in which case they will match any instruction matching the 105 // opcode. They may also be called with matches for the operands and with an 106 // optional capture. (The capture must be the first argument.) Some examples of 107 // these helpers and their equivalents are provided below. 
108 109 // Example nullary instruction: 110 // Parameter() == Op().WithOpcode(HloOpcode::kParameter) 111 // Parameter(&a) == Op(&a).WithOpcode(HloOpcode::kParameter) 112 // 113 // Example unary instruction: 114 // Abs() == Op().WithOpcode(HloOpcode::kAbs) 115 // Abs(Op(&a)) == Op().WithOpcode(HloOpcode::kAbs) 116 // .WithOperand(0, Op(&a))) 117 // Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs) 118 // .WithOperand(0, Op(&b)) 119 // 120 // Commutative binary instructions have a special form that accepts either order 121 // of args, e.g.: 122 // 123 // AddAnyOrder(Parameter(1), Abs()) == 124 // Op().WithOpcode(HloOpcode::kAdd) 125 // .WithBinaryOperandsAnyOrder(Op().WithParameterNum(1), Abs()); 126 // 127 // MultiplyAnyOrder(&a, Parameter(), Abs()) // Captures the mul in `a`. 128 // 129 // The following additional helpers are provided. In all cases, `&a` is 130 // optional. 131 // 132 // ConstantScalar(&a) == Op(&a).IsConstantScalar(); 133 // ConstantScalar(&a, v) == Op(&a).IsConstantScalar(v); 134 // ConstantEffectiveScalar(&a) == Op(&a).IsConstantEffectiveScalar(); 135 // ConstantEffectiveScalar(&a, v) == Op(&a).IsConstantEffectiveScalar(&a, v) 136 // NonConstant(&a) == Op(&a).IsNonConstant() 137 // GetTupleElement(&a, b, index) == Op(&a).WithTupleIndex(index) 138 // .WithOperand(0, b); 139 // Parameter(&a, n) == Op(&a).WithParameterNum(n); 140 141 struct MatchOption { 142 // If true, actually capture matched item into the user pointer. 143 bool capture; 144 145 // An explanation for why we failed to match is streamed here, if not-null. 
146 std::ostream* explain_os; 147 }; 148 149 template <typename Value, typename Pattern> 150 bool Match(Value* value, const Pattern& pattern, 151 MatchOption option = {/*.capture=*/true, /*.explain_os=*/nullptr}) { 152 if (option.capture) { 153 auto new_option = option; 154 new_option.capture = false; 155 if (!pattern.Match(value, new_option)) { 156 return false; 157 } 158 } 159 return pattern.Match(value, option); 160 } 161 162 namespace match { 163 164 namespace detail { 165 166 // Macro for streaming to option.explain_os if it's not null. 167 // 168 // EXPLAIN << "value of foo(): " << foo() 169 // 170 #pragma push_macro("EXPLAIN") 171 #define EXPLAIN \ 172 if (option.explain_os) *option.explain_os 173 174 // kIndentInc is the additional number of spaces that we indent by when we 175 // increase the indent "by one". 176 enum { 177 kIndentInc = 2, 178 }; 179 180 // Writes a newline and then `indent` spaces. 181 // 182 // We follow an unintuitive convention in this file's pretty-printers: Indents 183 // are performed by the caller, not the callee. For example, if you want to 184 // print 185 // 186 // foo: 187 // - bar 188 // 189 // you'd do: 190 // 191 // Foo::DescribeTo(std::ostream* os, int64 indent) { 192 // *os << "foo:"; 193 // Indent(os, indent) // Create a newline at the *current* indent level. 194 // *os << " - "; 195 // bar.DescribeTo(os, indent + 3); // + 3 because strlen(" * ") == 3. 196 // } 197 // 198 // Bar::DescribeTo(std::ostream* os, int64 indent) { *os << "bar"; } 199 // 200 // Notice that Bar::DescribeTo() does not call Indent; the indenting is 201 // performed by Foo. This convention allows the caller to decide whether a 202 // matcher is preceded by a newline, which is important e.g. for the AllOf 203 // matcher. 204 // 205 // (Incidentally, indenting in Match's explanations is handled differently. 
206 // Indents are a common case in DescribeTo [we're printing a whole tree], but 207 // they're a special case in Match [we're printing only a path through the tree 208 // that encounters a failing node]. Indents in Match only appear when we 209 // encounter a failing disjunction, so we just handle them as a special case 210 // there.) 211 inline void Indent(std::ostream* os, int64 indent) { 212 *os << "\n"; 213 for (int64 i = 0; i < indent; ++i) { 214 *os << " "; 215 } 216 } 217 218 // SFINAE template that determines whether T declares a static member 219 // kIsTrivialMatcher. 220 // 221 // Trivial matchers get special treatment. For example, when printing 222 // a conjunction of matchers, we don't print "and" after a trivial matcher. This 223 // yields e.g. 224 // "a shape compatible with f32[1,2]" 225 // rather than 226 // "a shape AND compatible with f32[1,2]" 227 template <typename T, typename Dummy = void> 228 struct IsTrivialMatcher { 229 static constexpr bool value = false; 230 }; 231 template <typename T> 232 struct IsTrivialMatcher<T, 233 typename std::enable_if<T::kIsTrivialMatcher>::type> { 234 static constexpr bool value = true; 235 }; 236 237 template <typename Item, typename... Patterns> 238 class AllOfPattern { 239 public: 240 explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {} 241 242 bool Match(const Item* item, MatchOption option) const { 243 bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>()); 244 // This invariant is guaranteed by the top-level Match and AnyOf. 245 DCHECK(matched || !option.capture); 246 return matched; 247 } 248 249 bool Match(Item* item, MatchOption option) const { 250 bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>()); 251 // This invariant is guaranteed by the top-level Match and AnyOf. 
252 DCHECK(matched || !option.capture); 253 return matched; 254 } 255 256 void DescribeTo(std::ostream* os, int64 indent = 0) const { 257 DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent); 258 } 259 260 // Accessor for patterns_. Please don't use this outside of this file. 261 const std::tuple<Patterns...>& patterns() const { return patterns_; } 262 263 private: 264 template <typename ItemType, size_t index> 265 bool MatchImpl(ItemType* item, MatchOption option, 266 std::integral_constant<size_t, index>) const { 267 // We don't need to do any EXPLAINing here; it's all correctly handled by 268 // our sub-matchers (if any fail). 269 return std::get<index>(patterns_).Match(item, option) && 270 MatchImpl(item, option, std::integral_constant<size_t, index + 1>()); 271 } 272 273 template <typename ItemType> 274 bool MatchImpl(ItemType* item, MatchOption option, 275 std::integral_constant<size_t, sizeof...(Patterns)>) const { 276 return true; 277 } 278 279 // Pretty-printing a conjunction has some special cases to make it easy to 280 // read in the simple (common) case. 281 // 282 // If sizeof...(Patterns) == 1, prints as e.g. 283 // 284 // a shape 285 // 286 // If sizeof...(Patterns) == 2 and patterns_[0] is a trivial matcher (e.g. 
"a 287 // shape") prints as 288 // 289 // a shape compatible with f32[1,2] 290 // 291 // If sizeof...(Patterns) > 2 and patterns_[0] is a trivial matcher, prints as 292 // 293 // a shape: 294 // * compatible with f32[1,2] AND 295 // * that represents a scalar 296 // 297 // Otherwise prints as: 298 // 299 // all of: 300 // * foo AND 301 // * bar 302 // 303 template <size_t index> 304 void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>, 305 int64 indent) const { 306 constexpr bool first_is_trivial = 307 IsTrivialMatcher<typename std::remove_reference<decltype( 308 std::get<0>(patterns_))>::type>::value; 309 constexpr bool is_last = index == sizeof...(Patterns) - 1; 310 const auto& submatcher = std::get<index>(patterns_); 311 312 auto print_bulleted_item = [&] { 313 *os << " * "; 314 submatcher.DescribeTo(os, indent + 3); 315 if (!is_last) { 316 *os << " AND"; 317 Indent(os, indent); 318 } 319 }; 320 321 if (index == 0) { 322 if (first_is_trivial || is_last) { 323 submatcher.DescribeTo(os, indent + kIndentInc); 324 if (sizeof...(Patterns) > 2) { 325 *os << ":"; 326 Indent(os, indent); 327 } 328 } else { 329 *os << "all of:"; 330 Indent(os, indent); 331 print_bulleted_item(); 332 } 333 } else if (first_is_trivial && index == 1 && sizeof...(Patterns) == 2) { 334 *os << " "; 335 submatcher.DescribeTo(os, indent); 336 } else { 337 print_bulleted_item(); 338 } 339 DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent); 340 } 341 342 void DescribeToImpl(std::ostream* os, 343 std::integral_constant<size_t, sizeof...(Patterns)>, 344 int64 indent) const {} 345 346 std::tuple<Patterns...> patterns_; 347 }; 348 349 } // namespace detail 350 351 // Returns a pattern that represents the conjunction of all input patterns. All 352 // patterns need to match in order to have the AllOf pattern match. 353 template <typename Item, typename... 
Patterns> 354 detail::AllOfPattern<typename std::remove_const<Item>::type, Patterns...> AllOf( 355 const Patterns&... patterns) { 356 return detail::AllOfPattern<typename std::remove_const<Item>::type, 357 Patterns...>(patterns...); 358 } 359 360 // AllOf<AllOf<A, B...>, X, Y, ...> => AllOf<A, B, ..., X, Y, ...>. 361 // 362 // This transformation is necessary for good pretty-printing. 363 template <typename Item, typename... InnerPs, typename... OuterPs> 364 detail::AllOfPattern<typename std::remove_const<Item>::type, InnerPs..., 365 OuterPs...> 366 AllOf(const detail::AllOfPattern<Item, InnerPs...>& inner_p, 367 const OuterPs&... outer_ps) { 368 // Invoke constructor of AllOfPattern<Item, InnerPs..., OuterPs...>. 369 auto make_all_of = [](const InnerPs&... inner_ps, 370 const OuterPs&... outer_ps) { 371 return detail::AllOfPattern<typename std::remove_const<Item>::type, 372 InnerPs..., OuterPs...>(inner_ps..., 373 outer_ps...); 374 }; 375 return absl::apply(make_all_of, std::tuple_cat(inner_p.patterns(), 376 std::make_tuple(outer_ps...))); 377 } 378 379 namespace detail { 380 381 template <typename LayoutType, typename Impl> 382 class LayoutPattern; 383 384 // The base LayoutPattern implementation. Matches only if the layout is not 385 // nullptr. 386 class LayoutPatternBaseImpl { 387 public: 388 bool Match(const ::xla::Layout* layout, MatchOption option) const { 389 if (layout == nullptr) { 390 EXPLAIN << "Layout is null"; 391 return false; 392 } 393 return true; 394 } 395 396 void DescribeTo(std::ostream* os, int64 indent = 0) const { 397 *os << "a layout"; 398 } 399 400 static constexpr bool kIsTrivialMatcher = true; 401 }; 402 403 // A LayoutPattern implementation that matches only if the layout equals a 404 // Layout proto. 
405 class LayoutPatternEqualImpl { 406 public: 407 explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout) 408 : layout_(layout) {} 409 410 bool Match(const ::xla::Layout* layout, MatchOption option) const { 411 if (!LayoutUtil::Equal(*layout_, *layout)) { 412 EXPLAIN << "Layout " << LayoutUtil::HumanString(*layout) 413 << " is not equal to expected " 414 << LayoutUtil::HumanString(*layout_); 415 return false; 416 } 417 return true; 418 } 419 420 void DescribeTo(std::ostream* os, int64 indent = 0) const { 421 *os << "equal to " << LayoutUtil::HumanString(*layout_); 422 } 423 424 private: 425 const ::xla::Layout* layout_; 426 }; 427 428 // A LayoutPattern implementation that matches only if the layout has a given 429 // format. 430 class LayoutPatternFormatImpl { 431 public: 432 explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {} 433 434 bool Match(const ::xla::Layout* layout, MatchOption option) const { 435 if (layout->format() != format_) { 436 EXPLAIN << "Layout has format " << Format_Name(layout->format()) 437 << " but expected " << Format_Name(format_); 438 return false; 439 } 440 return true; 441 } 442 443 void DescribeTo(std::ostream* os, int64 indent = 0) const { 444 *os << "with format " << Format_Name(format_); 445 } 446 447 private: 448 Format format_; 449 }; 450 451 // A pattern that matches Layouts. 
452 template <typename LayoutType, typename Impl> 453 class LayoutPattern { 454 private: 455 template <typename NewImpl> 456 auto AppendImpl(NewImpl new_impl) const 457 -> LayoutPattern<LayoutType, 458 decltype(AllOf<Layout>(std::declval<Impl>(), 459 std::move(new_impl)))> { 460 auto new_allof = AllOf<Layout>(impl_, std::move(new_impl)); 461 return LayoutPattern<LayoutType, decltype(new_allof)>(std::move(new_allof), 462 matched_layout_); 463 } 464 465 public: 466 explicit constexpr LayoutPattern(const Impl& impl, 467 LayoutType** matched_layout) 468 : impl_(impl), matched_layout_(matched_layout) {} 469 470 // Returns true and captures the layout iff it matches the pattern. 471 bool Match(const ::xla::Layout* layout, MatchOption option) const { 472 if (impl_.Match(layout, option)) { 473 if (option.capture && matched_layout_) { 474 *matched_layout_ = layout; 475 } 476 return true; 477 } 478 return false; 479 } 480 481 // Returns true and captures the layout iff it matches the pattern. 482 bool Match(::xla::Layout* layout, MatchOption option) const { 483 if (impl_.Match(layout, option)) { 484 if (option.capture && matched_layout_) { 485 *matched_layout_ = layout; 486 } 487 return true; 488 } 489 return false; 490 } 491 492 void DescribeTo(std::ostream* os, int64 indent = 0) const { 493 impl_.DescribeTo(os, indent); 494 } 495 496 // Modifies the pattern to match only if the layout equals the given proto. 497 // The layout must outlive the returned pattern. 498 constexpr auto EqualTo(const ::xla::Layout* layout) const 499 -> decltype(this->AppendImpl(LayoutPatternEqualImpl(layout))) { 500 return AppendImpl(LayoutPatternEqualImpl(layout)); 501 } 502 503 // Modifies the pattern to match only if the layout has a dense format. 
504 constexpr auto WithDenseFormat() const 505 -> decltype(this->AppendImpl(LayoutPatternFormatImpl(DENSE))) { 506 return AppendImpl(LayoutPatternFormatImpl(DENSE)); 507 } 508 509 // Modifies the pattern to match only if the layout has a sparse format. 510 constexpr auto WithSparseFormat() const 511 -> decltype(this->AppendImpl(LayoutPatternFormatImpl(SPARSE))) { 512 return AppendImpl(LayoutPatternFormatImpl(SPARSE)); 513 } 514 515 private: 516 Impl impl_; 517 LayoutType** matched_layout_; 518 }; 519 520 template <typename Item, typename... Patterns> 521 class AnyOfPattern { 522 public: 523 explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {} 524 525 bool Match(const Item* item, MatchOption option) const { 526 return MatchImpl(item, option); 527 } 528 529 bool Match(Item* item, MatchOption option) const { 530 return MatchImpl(item, option); 531 } 532 533 void DescribeTo(std::ostream* os, int64 indent = 0) const { 534 *os << "any of:"; 535 Indent(os, indent); 536 DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent); 537 } 538 539 private: 540 template <typename ItemType> 541 bool MatchImpl(ItemType* item, MatchOption option) const { 542 // If we're generating an explanation, buffer it until we know we failed. 
543 absl::optional<std::stringstream> explanation; 544 MatchOption new_option = option; 545 if (option.explain_os) { 546 new_option.explain_os = &explanation.emplace(); 547 } 548 bool rv = MatchRecursiveImpl(item, new_option, 549 std::integral_constant<size_t, 0>()); 550 if (!rv && option.explain_os) { 551 EXPLAIN << "None of the following matchers succeeded:"; 552 EXPLAIN << explanation->str(); 553 } 554 return rv; 555 } 556 557 template <typename ItemType, size_t index> 558 bool MatchRecursiveImpl(ItemType* item, MatchOption option, 559 std::integral_constant<size_t, index>) const { 560 auto new_option = option; 561 new_option.capture = false; 562 563 absl::optional<std::stringstream> explanation; 564 if (option.explain_os) { 565 new_option.explain_os = &explanation.emplace(); 566 } 567 568 // Try to match the sub-pattern without capturing behavior. 569 if (std::get<index>(patterns_).Match(item, new_option)) { 570 // Capture the branch. 571 if (option.capture) { 572 // TODO(timshen): Currently the behavior can be exponential. Optimize it 573 // with memoization or recording the matched sub-pattern index, if it 574 // takes too long to run. 575 // 576 // Specifically, the "memoization" approach is to create an empty 577 // container with the key (pattern, instruction), and value as whether 578 // matched or not. 579 // 580 // Alternatively, we may run the pattern matching with captures off, but 581 // instead record a "trace" somewhere, indicating how exactly the 582 // pattern matches the input. For example, the trace information for 583 // AnyOf will be a runtime number indicate which sub-pattern is matched. 584 // Then we run another pass to do captures only with the help of the 585 // trace. 
586 bool matched = std::get<index>(patterns_).Match(item, option); 587 DCHECK(matched); 588 } 589 return true; 590 } 591 if (option.explain_os) { 592 EXPLAIN << "\nMatcher #" << index + 1; 593 EXPLAIN << "\n - "; 594 std::get<index>(patterns_).DescribeTo(option.explain_os, /*indent=*/3); 595 EXPLAIN << "\nfailed with"; 596 EXPLAIN << "\n - "; 597 EXPLAIN << absl::StrReplaceAll(explanation->str(), {{"\n", "\n "}}); 598 } 599 return MatchRecursiveImpl(item, option, 600 std::integral_constant<size_t, index + 1>()); 601 } 602 603 template <typename ItemType> 604 bool MatchRecursiveImpl( 605 ItemType* item, MatchOption option, 606 std::integral_constant<size_t, sizeof...(Patterns)>) const { 607 return false; 608 } 609 610 template <size_t index> 611 void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>, 612 int64 indent) const { 613 *os << " - "; 614 std::get<index>(patterns_).DescribeTo(os, indent + 3); 615 if (index != sizeof...(Patterns) - 1) { 616 *os << " OR"; 617 Indent(os, indent); 618 } 619 DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent); 620 } 621 622 void DescribeToImpl(std::ostream* os, 623 std::integral_constant<size_t, sizeof...(Patterns)>, 624 int64 indent) const {} 625 626 std::tuple<Patterns...> patterns_; 627 }; 628 629 } // namespace detail 630 631 // Returns a pattern that represents the logical disjunction of the input 632 // patterns. The returned pattern matches from left to right, and stops on the 633 // first match. 634 template <typename Item, typename... Patterns> 635 detail::AnyOfPattern<typename std::remove_const<Item>::type, Patterns...> AnyOf( 636 const Patterns&... patterns) { 637 return detail::AnyOfPattern<typename std::remove_const<Item>::type, 638 Patterns...>(patterns...); 639 } 640 641 // Creates a layout pattern that will capture the matched layout in the 642 // argument. 
643 inline constexpr detail::LayoutPattern<const ::xla::Layout, 644 detail::LayoutPatternBaseImpl> 645 Layout(const ::xla::Layout** matched_layout = nullptr) { 646 return detail::LayoutPattern<const ::xla::Layout, 647 detail::LayoutPatternBaseImpl>( 648 detail::LayoutPatternBaseImpl(), matched_layout); 649 } 650 651 // Creates a layout pattern that will capture the matched layout in the 652 // argument. 653 inline constexpr detail::LayoutPattern<::xla::Layout, 654 detail::LayoutPatternBaseImpl> 655 Layout(::xla::Layout** matched_layout) { 656 return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>( 657 detail::LayoutPatternBaseImpl(), matched_layout); 658 } 659 660 namespace detail { 661 662 template <typename ShapeType, typename Impl> 663 class ShapePattern; 664 665 // The base ShapePattern implementation. Matches only if the shape is not 666 // nullptr. 667 class ShapePatternBaseImpl { 668 public: 669 bool Match(const ::xla::Shape* shape, MatchOption option) const { 670 if (shape == nullptr) { 671 EXPLAIN << "Shape is null"; 672 } 673 return shape != nullptr; 674 } 675 676 void DescribeTo(std::ostream* os, int64 indent = 0) const { 677 *os << "a shape"; 678 } 679 680 static constexpr bool kIsTrivialMatcher = true; 681 }; 682 683 // A ShapePattern implementation that matches only if the shape equals a Shape 684 // proto. 
685 class ShapePatternEqualImpl { 686 public: 687 explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape) 688 : shape_(shape) {} 689 690 bool Match(const ::xla::Shape* shape, MatchOption option) const { 691 if (!ShapeUtil::Equal(*shape_, *shape)) { 692 EXPLAIN << "Shape not equal to " 693 << ShapeUtil::HumanStringWithLayout(*shape_); 694 return false; 695 } 696 return true; 697 } 698 699 void DescribeTo(std::ostream* os, int64 indent = 0) const { 700 *os << "equal to " << ShapeUtil::HumanStringWithLayout(*shape_); 701 } 702 703 private: 704 const ::xla::Shape* shape_; 705 }; 706 707 // A ShapePattern implementation that matches only if the shape is compatible to 708 // a Shape proto. 709 class ShapePatternCompatibleImpl { 710 public: 711 explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape) 712 : shape_(shape) {} 713 714 bool Match(const ::xla::Shape* shape, MatchOption option) const { 715 if (!ShapeUtil::Compatible(*shape_, *shape)) { 716 EXPLAIN << "Shape not compatible with " 717 << ShapeUtil::HumanString(*shape_); 718 return false; 719 } 720 return true; 721 } 722 723 void DescribeTo(std::ostream* os, int64 indent = 0) const { 724 *os << "compatible with " << ShapeUtil::HumanString(*shape_); 725 } 726 727 private: 728 const ::xla::Shape* shape_; 729 }; 730 731 // A ShapePattern implementation that matches only if the shape has a given 732 // element type. 
733 class ShapePatternElementTypeImpl { 734 public: 735 explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type) 736 : element_type_(element_type) {} 737 738 bool Match(const ::xla::Shape* shape, MatchOption option) const { 739 if (shape->element_type() != element_type_) { 740 EXPLAIN << "Shape does not have element type " 741 << PrimitiveType_Name(element_type_); 742 return false; 743 } 744 return true; 745 } 746 747 void DescribeTo(std::ostream* os, int64 indent = 0) const { 748 *os << "with element type " << PrimitiveType_Name(element_type_); 749 } 750 751 private: 752 PrimitiveType element_type_; 753 }; 754 755 // A ShapePattern implementation that matches only if the shape is scalar. 756 class ShapePatternIsScalarImpl { 757 public: 758 explicit constexpr ShapePatternIsScalarImpl() {} 759 760 bool Match(const ::xla::Shape* shape, MatchOption option) const { 761 if (!ShapeUtil::IsScalar(*shape)) { 762 EXPLAIN << "Shape is not a scalar"; 763 return false; 764 } 765 return true; 766 } 767 768 void DescribeTo(std::ostream* os, int64 indent = 0) const { 769 *os << "that represents a scalar"; 770 } 771 }; 772 773 // A ShapePattern implementation that matches only if the shape is an array 774 class ShapePatternIsArrayImpl { 775 public: 776 explicit constexpr ShapePatternIsArrayImpl() {} 777 778 bool Match(const ::xla::Shape* shape, MatchOption option) const { 779 if (!shape->IsArray()) { 780 EXPLAIN << "Shape is not an array"; 781 return false; 782 } 783 return true; 784 } 785 786 void DescribeTo(std::ostream* os, int64 indent = 0) const { 787 *os << "that represents an array"; 788 } 789 }; 790 791 // A ShapePattern implementation that matches only if the shape is a tuple. 
792 class ShapePatternIsTupleImpl { 793 public: 794 explicit constexpr ShapePatternIsTupleImpl() {} 795 796 bool Match(const ::xla::Shape* shape, MatchOption option) const { 797 if (!shape->IsTuple()) { 798 EXPLAIN << "Shape is not a tuple"; 799 return false; 800 } 801 return true; 802 } 803 804 void DescribeTo(std::ostream* os, int64 indent = 0) const { 805 *os << "that represents a tuple"; 806 } 807 }; 808 809 // A ShapePattern implementation that matches only if the shape is an effective 810 // scalar. 811 class ShapePatternEffectiveScalarImpl { 812 public: 813 explicit constexpr ShapePatternEffectiveScalarImpl() {} 814 815 bool Match(const ::xla::Shape* shape, MatchOption option) const { 816 if (!ShapeUtil::IsEffectiveScalar(*shape)) { 817 EXPLAIN << "Shape is not an effective scalar"; 818 return false; 819 } 820 return true; 821 } 822 823 void DescribeTo(std::ostream* os, int64 indent = 0) const { 824 *os << "that is an effective scalar"; 825 } 826 }; 827 828 // A ShapePattern implementation that matches only if the shape has a given 829 // rank. 830 class ShapePatternRankImpl { 831 public: 832 explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {} 833 834 bool Match(const ::xla::Shape* shape, MatchOption option) const { 835 if (shape->rank() != rank_) { 836 if (rank_ == 0) { 837 EXPLAIN << "Shape is not a scalar"; 838 } else { 839 EXPLAIN << "Shape does not have rank " << rank_; 840 } 841 return false; 842 } 843 return true; 844 } 845 846 void DescribeTo(std::ostream* os, int64 indent = 0) const { 847 if (rank_ == 0) { 848 *os << "that is a scalar"; 849 } else { 850 *os << "that has " << rank_ << " dimension" << (rank_ != 1 ? "s" : ""); 851 } 852 } 853 854 private: 855 int64 rank_; 856 }; 857 858 // A ShapePattern implementation that matches only if the shape has a layout 859 // that matches a given pattern. 
860 template <typename LayoutType, typename LayoutImpl> 861 class ShapePatternLayoutImpl { 862 public: 863 explicit constexpr ShapePatternLayoutImpl( 864 const LayoutPattern<LayoutType, LayoutImpl>& layout) 865 : layout_(layout) {} 866 867 bool Match(const ::xla::Shape* shape, MatchOption option) const { 868 return LayoutUtil::HasLayout(*shape) && 869 layout_.Match(&shape->layout(), option); 870 } 871 872 bool Match(Shape* shape, MatchOption option) const { 873 if (!LayoutUtil::HasLayout(*shape)) { 874 EXPLAIN << "Shape does not have a layout"; 875 return false; 876 } 877 if (!layout_.Match(shape->mutable_layout(), option)) { 878 EXPLAIN << "\nin layout"; 879 return false; 880 } 881 return true; 882 } 883 884 void DescribeTo(std::ostream* os, int64 indent = 0) const { 885 *os << "with"; 886 Indent(os, indent + kIndentInc); 887 layout_.DescribeTo(os, indent + kIndentInc); 888 } 889 890 private: 891 LayoutPattern<LayoutType, LayoutImpl> layout_; 892 }; 893 894 // A ShapePattern implementation that matches only if the shape has a subshape 895 // that matches a given pattern. 
896 template <typename SubshapeType, typename SubshapeImpl> 897 class ShapePatternSubshapeImpl { 898 public: 899 explicit ShapePatternSubshapeImpl( 900 ShapeIndexView index, 901 const ShapePattern<SubshapeType, SubshapeImpl>& subshape) 902 : index_(index), subshape_(subshape) {} 903 904 bool Match(const ::xla::Shape* shape, MatchOption option) const { 905 return MatchImpl(shape, option); 906 } 907 908 bool Match(::xla::Shape* shape, MatchOption option) const { 909 return MatchImpl(shape, option); 910 } 911 912 void DescribeTo(std::ostream* os, int64 indent = 0) const { 913 *os << "with subshape at index " << index_.ToString() << " which is"; 914 Indent(os, indent + kIndentInc); 915 subshape_.DescribeTo(os, indent + kIndentInc); 916 } 917 918 private: 919 Shape* GetSubshape(Shape* shape) const { 920 return ShapeUtil::GetMutableSubshape(shape, index_); 921 } 922 const Shape* GetSubshape(const Shape* shape) const { 923 return &ShapeUtil::GetSubshape(*shape, index_); 924 } 925 926 template <typename ShapeType> 927 bool MatchImpl(ShapeType* shape, MatchOption option) const { 928 if (!ShapeUtil::IndexIsValid(*shape, index_)) { 929 EXPLAIN << "No subshape at " << index_.ToString(); 930 return false; 931 } 932 if (!subshape_.Match(GetSubshape(shape), option)) { 933 EXPLAIN << "\nin subshape at " << index_.ToString(); 934 return false; 935 } 936 return true; 937 } 938 939 ShapeIndexView index_; 940 ShapePattern<SubshapeType, SubshapeImpl> subshape_; 941 }; 942 943 // A pattern that matches Shapes. 
944 template <typename ShapeType, typename Impl> 945 class ShapePattern { 946 private: 947 template <typename NewImpl> 948 auto AppendImpl(NewImpl new_impl) const 949 -> ShapePattern<ShapeType, decltype(AllOf<Shape>(std::declval<Impl>(), 950 std::move(new_impl)))> { 951 auto new_all_of = AllOf<Shape>(impl_, std::move(new_impl)); 952 return ShapePattern<ShapeType, decltype(new_all_of)>(std::move(new_all_of), 953 matched_shape_); 954 } 955 956 public: 957 explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape) 958 : impl_(impl), matched_shape_(matched_shape) {} 959 960 // Returns true and captures the shape iff it matches the pattern. 961 bool Match(const ::xla::Shape* shape, MatchOption option) const { 962 if (impl_.Match(shape, option)) { 963 if (option.capture && matched_shape_) { 964 *matched_shape_ = shape; 965 } 966 return true; 967 } 968 if (shape) { 969 EXPLAIN << "\nin " 970 << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape) 971 : ShapeUtil::HumanString(*shape)); 972 } 973 return false; 974 } 975 976 // Returns true and captures the shape iff it matches the pattern. 977 bool Match(::xla::Shape* shape, MatchOption option) const { 978 if (impl_.Match(shape, option)) { 979 if (option.capture && matched_shape_) { 980 *matched_shape_ = shape; 981 } 982 return true; 983 } 984 EXPLAIN << "\nin " 985 << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape) 986 : ShapeUtil::HumanString(*shape)); 987 return false; 988 } 989 990 void DescribeTo(std::ostream* os, int64 indent = 0) const { 991 return impl_.DescribeTo(os, indent); 992 } 993 994 // Modifies the pattern to match only if the shape equals the given proto. 995 // The layout must outlive the returned pattern. 
996 constexpr auto EqualTo(const ::xla::Shape* shape) const 997 -> decltype(this->AppendImpl(ShapePatternEqualImpl(shape))) { 998 return AppendImpl(ShapePatternEqualImpl(shape)); 999 } 1000 1001 // Modifies the pattern to match only if the shape is compatible to the given 1002 // proto. The layout must outlive the returned pattern. 1003 constexpr auto CompatibleTo(const ::xla::Shape* shape) const 1004 -> decltype(this->AppendImpl(ShapePatternCompatibleImpl(shape))) { 1005 return AppendImpl(ShapePatternCompatibleImpl(shape)); 1006 } 1007 1008 // Modifies the pattern to match only if the shape has the given element type. 1009 constexpr auto WithElementType(PrimitiveType element_type) const 1010 -> decltype(this->AppendImpl(ShapePatternElementTypeImpl(element_type))) { 1011 return AppendImpl(ShapePatternElementTypeImpl(element_type)); 1012 } 1013 1014 // Modifies the pattern to match only if the shape is scalar. 1015 constexpr auto IsScalar() const 1016 -> decltype(this->AppendImpl(ShapePatternIsScalarImpl())) { 1017 return AppendImpl(ShapePatternIsScalarImpl()); 1018 } 1019 1020 // Modifies the pattern to match only if the shape is an array. 1021 constexpr auto IsArray() const 1022 -> decltype(this->AppendImpl(ShapePatternIsArrayImpl())) { 1023 return AppendImpl(ShapePatternIsArrayImpl()); 1024 } 1025 1026 // Modifies the pattern to match only if the shape is a tuple. 1027 constexpr auto IsTuple() const 1028 -> decltype(this->AppendImpl(ShapePatternIsTupleImpl())) { 1029 return AppendImpl(ShapePatternIsTupleImpl()); 1030 } 1031 1032 constexpr auto IsEffectiveScalar() const 1033 -> decltype(this->AppendImpl(ShapePatternEffectiveScalarImpl())) { 1034 return AppendImpl(ShapePatternEffectiveScalarImpl()); 1035 } 1036 1037 // Modifies the pattern to match only if the shape has the given rank. 
1038 constexpr auto WithRank(int64 rank) const 1039 -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) { 1040 return AppendImpl(ShapePatternRankImpl(rank)); 1041 } 1042 1043 // Modifies the pattern to match only if the shape has a layout that matches 1044 // the given pattern. 1045 template <typename LayoutType, typename LayoutImpl> 1046 auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const 1047 -> decltype(this->AppendImpl( 1048 ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout))) { 1049 return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout)); 1050 } 1051 1052 constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const 1053 -> decltype(this->WithLayout(Layout().EqualTo(layout))) { 1054 return WithLayout(Layout().EqualTo(layout)); 1055 } 1056 1057 constexpr auto IsDenseArray() const 1058 -> decltype(this->WithLayout(Layout().WithDenseFormat())) { 1059 return WithLayout(Layout().WithDenseFormat()); 1060 } 1061 1062 constexpr auto IsSparseArray() const 1063 -> decltype(this->WithLayout(Layout().WithSparseFormat())) { 1064 return WithLayout(Layout().WithSparseFormat()); 1065 } 1066 1067 // Modifies the pattern to match only if the shape has a subshape that matches 1068 // the given pattern. 
1069 template <typename SubshapeType, typename SubshapeImpl> 1070 auto WithSubshape(ShapeIndexView index, 1071 const ShapePattern<SubshapeType, SubshapeImpl>& subshape) 1072 const -> decltype(this->AppendImpl( 1073 ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, 1074 subshape))) { 1075 return AppendImpl( 1076 ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape)); 1077 } 1078 1079 ShapePattern<ShapeType, 1080 AllOfPattern<Shape, Impl, 1081 ShapePatternSubshapeImpl< 1082 const ::xla::Shape, 1083 AllOfPattern<::xla::Shape, ShapePatternBaseImpl, 1084 ShapePatternEqualImpl>>>> 1085 WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const { 1086 return WithSubshape(index, 1087 ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>( 1088 ShapePatternBaseImpl(), nullptr) 1089 .EqualTo(shape)); 1090 } 1091 1092 ShapePattern<ShapeType, 1093 AllOfPattern<Shape, Impl, 1094 ShapePatternSubshapeImpl< 1095 const ::xla::Shape, 1096 AllOfPattern<::xla::Shape, ShapePatternBaseImpl, 1097 ShapePatternCompatibleImpl>>>> 1098 WithSubshapeCompatibleTo(ShapeIndexView index, 1099 const ::xla::Shape* shape) const { 1100 return WithSubshape(index, 1101 ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>( 1102 ShapePatternBaseImpl(), nullptr) 1103 .CompatibleTo(shape)); 1104 } 1105 1106 private: 1107 Impl impl_; 1108 ShapeType** matched_shape_; 1109 }; 1110 1111 } // namespace detail 1112 1113 // Creates a shape pattern that will capture the matched layout in the argument. 1114 inline constexpr detail::ShapePattern<const ::xla::Shape, 1115 detail::ShapePatternBaseImpl> 1116 Shape(const ::xla::Shape** matched_shape = nullptr) { 1117 return detail::ShapePattern<const ::xla::Shape, detail::ShapePatternBaseImpl>( 1118 detail::ShapePatternBaseImpl(), matched_shape); 1119 } 1120 1121 // Creates a shape pattern that will capture the matched layout in the argument. 
1122 inline constexpr detail::ShapePattern<::xla::Shape, 1123 detail::ShapePatternBaseImpl> 1124 Shape(::xla::Shape** matched_shape) { 1125 return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>( 1126 detail::ShapePatternBaseImpl(), matched_shape); 1127 } 1128 1129 namespace detail { 1130 1131 // Overloads to get a const or non-const operand out of an instruction. 1132 inline HloInstruction* HloOperand(HloInstruction* instr, int64 idx) { 1133 return instr->mutable_operand(idx); 1134 } 1135 inline const HloInstruction* HloOperand(const HloInstruction* instr, 1136 int64 idx) { 1137 return instr->operand(idx); 1138 } 1139 1140 // Pretty-printer for HloInstruction. Sort of like ToShortString, but with 1141 // fewer %s and more shapes. 1142 inline string InstToString(const HloInstruction* inst) { 1143 return inst->ToString( 1144 HloPrintOptions().set_print_metadata(false).set_print_percent(false)); 1145 } 1146 1147 template <typename HloInstructionType, typename Impl> 1148 class HloInstructionPattern; 1149 1150 // The base HloInstructionPattern implementation. Matches only if the 1151 // instruction is not nullptr. 1152 class HloInstructionPatternBaseImpl { 1153 public: 1154 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { 1155 if (inst == nullptr) { 1156 EXPLAIN << "HloInstruction* is null"; 1157 return false; 1158 } 1159 return true; 1160 } 1161 1162 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1163 *os << "an HloInstruction"; 1164 } 1165 1166 static constexpr bool kIsTrivialMatcher = true; 1167 }; 1168 1169 // An HloInstructionPattern implementation that matches only if the instruction 1170 // has a given name. 
1171 class HloInstructionPatternNameImpl { 1172 public: 1173 explicit HloInstructionPatternNameImpl(absl::string_view name) 1174 : name_(name) {} 1175 1176 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { 1177 if (inst->name() != name_) { 1178 EXPLAIN << "HloInstruction not named \"" << name_ << "\""; 1179 return false; 1180 } 1181 return true; 1182 } 1183 1184 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1185 *os << "named \"" << name_ << "\""; 1186 } 1187 1188 private: 1189 absl::string_view name_; 1190 }; 1191 1192 // An HloInstructionPattern implementation that matches only if the instruction 1193 // equals a particular pointer. 1194 class HloInstructionIsImpl { 1195 public: 1196 explicit HloInstructionIsImpl(const HloInstruction* inst) : inst_(inst) {} 1197 1198 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { 1199 if (inst != inst_) { 1200 EXPLAIN << "HloInstruction " << inst << " is not " << inst_ << " (" 1201 << InstToString(inst_) << ")"; 1202 return false; 1203 } 1204 return true; 1205 } 1206 1207 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1208 *os << "which is " << inst_ << " (" << InstToString(inst_) << ")"; 1209 } 1210 1211 private: 1212 const HloInstruction* inst_; 1213 }; 1214 1215 // An HloInstructionPattern implementation that matches only if the instruction 1216 // has a given opcode. 
1217 class HloInstructionPatternOpcodeImpl { 1218 public: 1219 explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode, 1220 bool invert) 1221 : opcode_(opcode), invert_(invert) {} 1222 1223 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { 1224 if (invert_ && inst->opcode() == opcode_) { 1225 EXPLAIN << "HloInstruction has opcode " << HloOpcodeString(opcode_) 1226 << ", expected anything else"; 1227 return false; 1228 } 1229 if (!invert_ && inst->opcode() != opcode_) { 1230 EXPLAIN << "HloInstruction doesn't have opcode " 1231 << HloOpcodeString(opcode_); 1232 return false; 1233 } 1234 return true; 1235 } 1236 1237 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1238 if (!invert_) { 1239 *os << "with opcode " << HloOpcodeString(opcode_); 1240 } else { 1241 *os << "with any opcode other than " << HloOpcodeString(opcode_); 1242 } 1243 } 1244 1245 private: 1246 HloOpcode opcode_; 1247 bool invert_; 1248 }; 1249 1250 // An HloInstructionPattern implementation that matches only if the instruction 1251 // has the given number of operands. 1252 class HloInstructionPatternNumOperandsImpl { 1253 public: 1254 explicit constexpr HloInstructionPatternNumOperandsImpl(int64 num_operands) 1255 : num_operands_(num_operands) {} 1256 1257 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { 1258 if (inst->operand_count() != num_operands_) { 1259 EXPLAIN << "HloInstruction doesn't have " << num_operands_ << " operands"; 1260 return false; 1261 } 1262 return true; 1263 } 1264 1265 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1266 *os << "with " << num_operands_ << " operand" 1267 << (num_operands_ != 1 ? "s" : ""); 1268 } 1269 1270 private: 1271 int64 num_operands_; 1272 }; 1273 1274 // An HloInstructionPattern implementation that matches only if the instruction 1275 // has a shape that matches a given pattern. 
1276 template <typename ShapeType, typename ShapeImpl> 1277 class HloInstructionPatternShapeImpl { 1278 public: 1279 explicit constexpr HloInstructionPatternShapeImpl( 1280 const ShapePattern<ShapeType, ShapeImpl>& shape) 1281 : shape_(shape) {} 1282 1283 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { 1284 if (!shape_.Match(&inst->shape(), option)) { 1285 EXPLAIN << "\nin output shape"; 1286 return false; 1287 } 1288 return true; 1289 } 1290 1291 bool Match(::xla::HloInstruction* inst, MatchOption option) const { 1292 if (!shape_.Match(inst->mutable_shape(), option)) { 1293 EXPLAIN << "\nin output shape"; 1294 return false; 1295 } 1296 return true; 1297 } 1298 1299 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1300 *os << "outputting"; 1301 Indent(os, indent + kIndentInc); 1302 shape_.DescribeTo(os, indent + kIndentInc); 1303 } 1304 1305 private: 1306 ShapePattern<ShapeType, ShapeImpl> shape_; 1307 }; 1308 1309 // An HloInstructionPattern implementation that matches only if the instruction 1310 // has an operand that matches a given pattern. 
1311 template <typename OperandType, typename OperandImpl> 1312 class HloInstructionPatternOperandImpl { 1313 public: 1314 explicit constexpr HloInstructionPatternOperandImpl( 1315 int64 operand_index, 1316 const HloInstructionPattern<OperandType, OperandImpl>& operand) 1317 : operand_index_(operand_index), operand_(operand) {} 1318 1319 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { 1320 return MatchImpl(inst, option); 1321 } 1322 1323 bool Match(::xla::HloInstruction* inst, MatchOption option) const { 1324 return MatchImpl(inst, option); 1325 } 1326 1327 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1328 *os << "with operand " << operand_index_ << " which is:"; 1329 Indent(os, indent + kIndentInc); 1330 operand_.DescribeTo(os, indent + kIndentInc); 1331 } 1332 1333 private: 1334 template <typename HloInstructionType> 1335 bool MatchImpl(HloInstructionType* inst, MatchOption option) const { 1336 if (operand_index_ >= inst->operand_count()) { 1337 EXPLAIN << "desired operand index " << operand_index_ 1338 << " is out of bounds"; 1339 return false; 1340 } 1341 if (!operand_.Match(HloOperand(inst, operand_index_), option)) { 1342 EXPLAIN << "\nin operand " << operand_index_; 1343 return false; 1344 } 1345 return true; 1346 } 1347 1348 int64 operand_index_; 1349 HloInstructionPattern<OperandType, OperandImpl> operand_; 1350 }; 1351 1352 // Matches a binary instruction whose operands come in any order. 
1353 template <typename OperandType1, typename OperandImpl1, typename OperandType2, 1354 typename OperandImpl2> 1355 class HloInstructionPatternBinaryOperandsAnyOrderImpl { 1356 public: 1357 explicit constexpr HloInstructionPatternBinaryOperandsAnyOrderImpl( 1358 const HloInstructionPattern<OperandType1, OperandImpl1>& op1, 1359 const HloInstructionPattern<OperandType2, OperandImpl2>& op2) 1360 : op1_(op1), op2_(op2) {} 1361 1362 bool Match(HloInstruction* inst, MatchOption option) const { 1363 return MatchImpl(inst, option); 1364 } 1365 1366 bool Match(const HloInstruction* inst, MatchOption option) const { 1367 return MatchImpl(inst, option); 1368 } 1369 1370 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1371 *os << "with two operands in either order:"; 1372 Indent(os, indent); 1373 *os << " - "; 1374 op1_.DescribeTo(os, indent + 3); 1375 Indent(os, indent); 1376 *os << " - "; 1377 op2_.DescribeTo(os, indent + 3); 1378 } 1379 1380 private: 1381 HloInstruction* operand(HloInstruction* inst, int64 idx) const { 1382 return inst->mutable_operand(idx); 1383 } 1384 const HloInstruction* operand(const HloInstruction* inst, int64 idx) const { 1385 return inst->operand(idx); 1386 } 1387 1388 template <typename HloInstructionType> 1389 bool MatchImpl(HloInstructionType* inst, MatchOption option) const { 1390 // We could implement this using AnyOf and AllOf matchers, but the templates 1391 // get pretty difficult to debug, since any compile error herein becomes 1392 // not-an-error via SFINAE. Also this way lets us give better messages on 1393 // failure. 1394 if (inst->operand_count() != 2) { 1395 EXPLAIN << "HloInstruction did not have two operands"; 1396 return false; 1397 } 1398 1399 // If we're not generating explanations, this is pretty simple. 
1400 if (!option.explain_os) { 1401 auto try_match = [&](int64 idx1, int64 idx2) { 1402 MatchOption new_option = option; 1403 new_option.capture = false; 1404 if (op1_.Match(operand(inst, idx1), new_option) && 1405 op2_.Match(operand(inst, idx2), new_option)) { 1406 if (option.capture) { 1407 bool matched = op1_.Match(operand(inst, idx1), option) && 1408 op2_.Match(operand(inst, idx2), option); 1409 DCHECK(matched); 1410 } 1411 return true; 1412 } 1413 return false; 1414 }; 1415 return try_match(0, 1) || try_match(1, 0); 1416 } 1417 1418 // If we are generating explanations, we have some work to do in order to 1419 // generate a helpful error. 1420 // 1421 // First, try all four operand/matcher combinations, recording the 1422 // failure explanations separately from option.explain_os. matches[i][j] 1423 // tells us if matcher_i matches operand j. 1424 bool matches[/*matcher*/ 2][/*operand*/ 2]; 1425 std::stringstream explanations[/*matcher*/ 2][/*operand*/ 2]; 1426 for (int i = 0; i < 2; ++i) { 1427 for (int j = 0; j < 2; ++j) { 1428 MatchOption new_option = option; 1429 new_option.capture = false; 1430 new_option.explain_os = &explanations[i][j]; 1431 matches[i][j] = i == 0 ? op1_.Match(operand(inst, j), new_option) 1432 : op2_.Match(operand(inst, j), new_option); 1433 } 1434 } 1435 1436 // Check if the match succeeded. 1437 for (int i = 0; i < 2; ++i) { 1438 if (matches[0][i] && matches[1][(i + 1) % 2]) { 1439 // Rerun the matches with capture enabled if necessary. 
1440 if (option.capture) { 1441 auto* operand1 = operand(inst, i); 1442 auto* operand2 = operand(inst, (i + 1) % 2); 1443 bool matched = 1444 op1_.Match(operand1, option) && op2_.Match(operand2, option); 1445 DCHECK(matched); 1446 } 1447 return true; 1448 } 1449 } 1450 1451 auto describe_matcher = [&](int matcher_idx) { 1452 EXPLAIN << "\n - "; 1453 if (matcher_idx == 0) { 1454 op1_.DescribeTo(option.explain_os, /*indent=*/3); 1455 } else { 1456 CHECK_EQ(matcher_idx, 1); 1457 op2_.DescribeTo(option.explain_os, /*indent=*/3); 1458 } 1459 for (int i = 0; i < 2; ++i) { 1460 if (matches[matcher_idx][/*operand*/ i]) { 1461 continue; 1462 } 1463 EXPLAIN << "\ndoes not match " << (i == 0 ? "LHS" : "RHS") << ":\n"; 1464 EXPLAIN << " - "; 1465 EXPLAIN << absl::StrReplaceAll( 1466 explanations[matcher_idx][/*operand*/ i].str(), {{"\n", "\n "}}); 1467 } 1468 }; 1469 1470 // If we failed to match, one of the following is true: 1471 // 1. op1 (op2) matches neither LHS nor RHS, or 1472 // 2. op1 and op2 both match LHS (RHS), but neither matches RHS (LHS). 1473 // We print different explanations depending on which case we're in. 1474 1475 // Case 1. 1476 bool wrote_explanation = false; 1477 for (int i = 0; !wrote_explanation && i < 2; ++i) { 1478 if (!matches[i][0] && !matches[i][1]) { 1479 EXPLAIN << "HloInstruction's operands (ignoring order) did not match " 1480 << (i == 0 ? "first" : "second") << " matcher. Specifically,"; 1481 describe_matcher(i); 1482 wrote_explanation = true; 1483 } 1484 } 1485 1486 // Case 2. 1487 for (int i = 0; !wrote_explanation && i < 2; ++i) { 1488 if (matches[/*matcher*/ 0][/*operand*/ i] && 1489 matches[/*matcher*/ 1][/*operand*/ i]) { 1490 CHECK(!matches[0][(i + 1) % 2]); 1491 CHECK(!matches[1][(i + 1) % 2]); 1492 CHECK(!wrote_explanation); 1493 EXPLAIN << "HloInstruction's " << (i == 1 ? "LHS" : "RHS") 1494 << " operand did not match either of the two matchers. 
" 1495 "Specifically,"; 1496 describe_matcher(0); 1497 EXPLAIN << "\nand"; 1498 describe_matcher(1); 1499 wrote_explanation = true; 1500 } 1501 } 1502 1503 CHECK(wrote_explanation); 1504 return false; 1505 } 1506 1507 HloInstructionPattern<OperandType1, OperandImpl1> op1_; 1508 HloInstructionPattern<OperandType2, OperandImpl2> op2_; 1509 }; 1510 1511 // An HloInstructionPattern implementation that matches only if the instruction 1512 // is a fusion node with a particular kind. 1513 class HloInstructionPatternFusionKindImpl { 1514 public: 1515 explicit constexpr HloInstructionPatternFusionKindImpl( 1516 ::xla::HloInstruction::FusionKind kind) 1517 : kind_(kind) {} 1518 1519 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { 1520 return MatchImpl(inst, option); 1521 } 1522 1523 bool Match(::xla::HloInstruction* inst, MatchOption option) const { 1524 return MatchImpl(inst, option); 1525 } 1526 1527 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1528 *os << "with fusion kind " << ToString(kind_); 1529 } 1530 1531 private: 1532 template <typename HloInstructionType> 1533 bool MatchImpl(HloInstructionType* inst, MatchOption option) const { 1534 if (inst->opcode() != HloOpcode::kFusion) { 1535 EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_) 1536 << "; it's not a fusion"; 1537 return false; 1538 } 1539 if (inst->fusion_kind() != kind_) { 1540 EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_); 1541 return false; 1542 } 1543 return true; 1544 } 1545 1546 ::xla::HloInstruction::FusionKind kind_; 1547 }; 1548 1549 // An HloInstructionPattern implementation that matches only if the instruction 1550 // is a kGetTupleElement with a particular tuple index. 
1551 class HloInstructionPatternTupleIndexImpl { 1552 public: 1553 explicit constexpr HloInstructionPatternTupleIndexImpl(int64 tuple_index) 1554 : tuple_index_(tuple_index) {} 1555 1556 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { 1557 return MatchImpl(inst, option); 1558 } 1559 1560 bool Match(::xla::HloInstruction* inst, MatchOption option) const { 1561 return MatchImpl(inst, option); 1562 } 1563 1564 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1565 *os << "which is a GTE with index " << tuple_index_; 1566 } 1567 1568 private: 1569 template <typename HloInstructionType> 1570 bool MatchImpl(HloInstructionType* inst, MatchOption option) const { 1571 if (inst->opcode() != HloOpcode::kGetTupleElement) { 1572 EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_ 1573 << "; it's not a GTE at all"; 1574 return false; 1575 } 1576 if (inst->tuple_index() != tuple_index_) { 1577 EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_; 1578 return false; 1579 } 1580 return true; 1581 } 1582 1583 int64 tuple_index_; 1584 }; 1585 1586 class HloInstructionPatternParameterNumImpl { 1587 public: 1588 explicit constexpr HloInstructionPatternParameterNumImpl(int64 parameter_num) 1589 : parameter_num_(parameter_num) {} 1590 1591 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { 1592 return MatchImpl(inst, option); 1593 } 1594 1595 bool Match(::xla::HloInstruction* inst, MatchOption option) const { 1596 return MatchImpl(inst, option); 1597 } 1598 1599 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1600 *os << "which is parameter " << parameter_num_; 1601 } 1602 1603 private: 1604 template <typename HloInstructionType> 1605 bool MatchImpl(HloInstructionType* inst, MatchOption option) const { 1606 if (inst->opcode() != HloOpcode::kParameter || 1607 inst->parameter_number() != parameter_num_) { 1608 EXPLAIN << "HloInstruction is not parameter " << parameter_num_; 1609 return 
false; 1610 } 1611 return true; 1612 } 1613 1614 int64 parameter_num_; 1615 }; 1616 1617 // Superclass that contains common code used by Op::WithOneUse() and 1618 // Op::WithOneUser(). 1619 class HloInstructionPatternOneUseOrUserImpl { 1620 protected: 1621 bool MatchOneUser(const HloInstruction* inst, MatchOption option) const { 1622 if (inst->user_count() != 1) { 1623 EXPLAIN << "HloInstruction has " << inst->user_count() 1624 << " users, but expected exactly one."; 1625 if (inst->user_count() > 1) { 1626 EXPLAIN << "\nAll users:"; 1627 for (const HloInstruction* user : inst->users()) { 1628 EXPLAIN << "\n - " << InstToString(user); 1629 } 1630 } 1631 return false; 1632 } 1633 return true; 1634 } 1635 }; 1636 1637 class HloInstructionPatternOneUseImpl 1638 : public HloInstructionPatternOneUseOrUserImpl { 1639 public: 1640 bool Match(const HloInstruction* inst, MatchOption option) const { 1641 if (!MatchOneUser(inst, option)) { 1642 return false; 1643 } 1644 1645 int64 use_count = absl::c_count_if( 1646 inst->users()[0]->operands(), 1647 [&](const HloInstruction* operand) { return operand == inst; }); 1648 if (use_count != 1) { 1649 EXPLAIN << "HloInstruction is used " << use_count 1650 << " times by its user, but is expected to be used just once: " 1651 << InstToString(inst->users()[0]); 1652 return false; 1653 } 1654 return true; 1655 } 1656 1657 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1658 *os << "which has exactly one use"; 1659 } 1660 }; 1661 1662 class HloInstructionPatternOneUserImpl 1663 : public HloInstructionPatternOneUseOrUserImpl { 1664 public: 1665 bool Match(const HloInstruction* inst, MatchOption option) const { 1666 return MatchOneUser(inst, option); 1667 } 1668 1669 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1670 *os << "which has exactly one user (but possibly is used multiple times by " 1671 "that instruction)"; 1672 } 1673 }; 1674 1675 class HloInstructionPatternComparisonDirectionImpl { 1676 public: 1677 
explicit constexpr HloInstructionPatternComparisonDirectionImpl( 1678 ComparisonDirection direction) 1679 : direction_(direction) {} 1680 1681 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { 1682 return MatchImpl(inst, option); 1683 } 1684 1685 bool Match(::xla::HloInstruction* inst, MatchOption option) const { 1686 return MatchImpl(inst, option); 1687 } 1688 1689 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1690 *os << "which has comparison direction " 1691 << ComparisonDirectionToString(direction_); 1692 } 1693 1694 private: 1695 template <typename HloInstructionType> 1696 bool MatchImpl(HloInstructionType* inst, MatchOption option) const { 1697 if (inst->opcode() != HloOpcode::kCompare || 1698 inst->comparison_direction() != direction_) { 1699 EXPLAIN << "HloInstruction is not comparison " 1700 << ComparisonDirectionToString(direction_); 1701 return false; 1702 } 1703 return true; 1704 } 1705 1706 ComparisonDirection direction_; 1707 }; 1708 1709 // Matches a constant scalar or effective scalar, optionally with a given value. 1710 template <typename ScalarTy> 1711 class HloConstantScalarImpl { 1712 public: 1713 explicit constexpr HloConstantScalarImpl(bool match_effective_scalar) 1714 : val_(absl::nullopt), match_effective_scalar_(match_effective_scalar) {} 1715 1716 constexpr HloConstantScalarImpl(ScalarTy val, bool match_effective_scalar) 1717 : val_(val), match_effective_scalar_(match_effective_scalar) {} 1718 1719 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { 1720 return MatchImpl(inst, option); 1721 } 1722 1723 bool Match(::xla::HloInstruction* inst, MatchOption option) const { 1724 return MatchImpl(inst, option); 1725 } 1726 1727 void DescribeTo(std::ostream* os, int64 indent = 0) const { 1728 *os << "which is a constant " 1729 << (match_effective_scalar_ ? 
"effective " : "") << "scalar"; 1730 if (val_.has_value()) { 1731 *os << " with value " << *val_; 1732 } 1733 } 1734 1735 private: 1736 template <typename InstTy> 1737 bool MatchImpl(InstTy* inst, MatchOption option) const { 1738 const auto* const_inst = DynCast<HloConstantInstruction>(inst); 1739 if (!const_inst) { 1740 EXPLAIN << "HloInstruction is not a constant"; 1741 return false; 1742 } 1743 if (match_effective_scalar_ && 1744 !ShapeUtil::IsEffectiveScalar(inst->shape())) { 1745 EXPLAIN << "HloInstruction is not an effective scalar"; 1746 return false; 1747 } 1748 if (!match_effective_scalar_ && !ShapeUtil::IsScalar(inst->shape())) { 1749 EXPLAIN << "HloInstruction is not a scalar"; 1750 return false; 1751 } 1752 if (!val_.has_value()) { 1753 return true; 1754 } 1755 1756 // Check that literal == static_cast<LitearlTy>(val) and 1757 // val == static_cast<ValTy>(literal). This is sufficient to ensure that 1758 // the two constant scalars are actually "equal". 1759 auto val_literal = LiteralUtil::CreateR0(*val_); 1760 auto literal_r0_or = const_inst->literal().Reshape({}); 1761 auto val_as_literal_ty_or = 1762 val_literal.Convert(const_inst->shape().element_type()); 1763 if (!literal_r0_or.ok() || !val_as_literal_ty_or.ok()) { 1764 EXPLAIN << "could not construct relevant Literals (how did this happen?)"; 1765 return false; 1766 } 1767 auto literal_r0 = std::move(literal_r0_or).ValueOrDie(); 1768 auto val_as_literal_ty = std::move(val_as_literal_ty_or).ValueOrDie(); 1769 auto literal_r0_as_val_ty_or = 1770 literal_r0.Convert(val_literal.shape().element_type()); 1771 bool rv = literal_r0_as_val_ty_or.ok() && // 1772 literal_r0_as_val_ty_or.ValueOrDie() == val_literal && 1773 literal_r0 == val_as_literal_ty; 1774 if (!rv) { 1775 EXPLAIN << "HloInstruction's constant value " 1776 << literal_r0.ToStringWithoutShape() 1777 << " did not match expected value " << *val_; 1778 } 1779 return rv; 1780 } 1781 1782 absl::optional<ScalarTy> val_; 1783 bool 
match_effective_scalar_; 1784 }; 1785 1786 // A pattern that matches HloInstructions. 1787 template <typename HloInstructionType, typename Impl> 1788 class HloInstructionPattern { 1789 private: 1790 template <typename NewImpl> 1791 auto AppendImpl(NewImpl new_impl) const -> HloInstructionPattern< 1792 HloInstructionType, decltype(AllOf<HloInstruction>( 1793 std::declval<Impl>(), std::move(new_impl)))> { 1794 auto new_allof = AllOf<HloInstruction>(impl_, std::move(new_impl)); 1795 return HloInstructionPattern<HloInstructionType, decltype(new_allof)>( 1796 std::move(new_allof), matched_inst_); 1797 } 1798 1799 public: 1800 explicit constexpr HloInstructionPattern(const Impl& impl, 1801 HloInstructionType** matched_inst) 1802 : impl_(impl), matched_inst_(matched_inst) {} 1803 1804 // Returns true and captures the instruction iff it matches the pattern. 1805 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { 1806 if (impl_.Match(inst, option)) { 1807 if (option.capture && matched_inst_) { 1808 *matched_inst_ = inst; 1809 } 1810 return true; 1811 } 1812 if (inst != nullptr) { 1813 EXPLAIN << "\nin " << InstToString(inst); 1814 } 1815 return false; 1816 } 1817 1818 // Returns true and captures the instruction iff it matches the pattern. 1819 bool Match(::xla::HloInstruction* inst, MatchOption option) const { 1820 if (impl_.Match(inst, option)) { 1821 if (option.capture && matched_inst_) { 1822 *matched_inst_ = inst; 1823 } 1824 return true; 1825 } 1826 EXPLAIN << "\nin " << InstToString(inst); 1827 return false; 1828 } 1829 1830 // Modifies the pattern to match only if the instruction has the given name. 1831 auto WithName(absl::string_view name) const 1832 -> decltype(this->AppendImpl(HloInstructionPatternNameImpl(name))) { 1833 return AppendImpl(HloInstructionPatternNameImpl(name)); 1834 } 1835 1836 // Modifies the pattern to match only if the instruction has the given opcode. 
1837 auto WithOpcode(HloOpcode opcode) const 1838 -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode, 1839 false))) { 1840 return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false)); 1841 } 1842 1843 auto WithNumOperands(int64 num_operands) const -> decltype( 1844 this->AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands))) { 1845 return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands)); 1846 } 1847 1848 // Modifies the pattern to match only if the instruction does not have the 1849 // given opcode. 1850 auto WithoutOpcode(HloOpcode opcode) const 1851 -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode, 1852 true))) { 1853 return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true)); 1854 } 1855 1856 constexpr auto Is(const HloInstruction* instr) const 1857 -> decltype(this->AppendImpl(HloInstructionIsImpl(instr))) { 1858 return AppendImpl(HloInstructionIsImpl(instr)); 1859 } 1860 1861 // Modifies the pattern to match only if the instruction is a constant. 1862 constexpr auto IsConstant() const 1863 -> decltype(this->WithOpcode(HloOpcode::kConstant)) { 1864 return WithOpcode(HloOpcode::kConstant); 1865 } 1866 1867 constexpr auto IsConstantScalar() const -> decltype(this->AppendImpl( 1868 HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false))) { 1869 return AppendImpl( 1870 HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false)); 1871 } 1872 1873 // This does not check that T has the same type as the instruction, so e.g. 1874 // IsConstantScalar(1.0) may match a constant of shape int32[]. 
1875 template <typename ScalarTy> 1876 constexpr auto IsConstantScalar(const ScalarTy& val) const 1877 -> decltype(this->AppendImpl(HloConstantScalarImpl<ScalarTy>( 1878 val, /*match_effective_scalar=*/false))) { 1879 return AppendImpl( 1880 HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/false)); 1881 } 1882 1883 constexpr auto IsConstantEffectiveScalar() const -> decltype(this->AppendImpl( 1884 HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true))) { 1885 return AppendImpl( 1886 HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true)); 1887 } 1888 1889 template <typename ScalarTy> 1890 constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const 1891 -> decltype(this->AppendImpl(HloConstantScalarImpl<ScalarTy>( 1892 val, /*match_effective_scalar=*/true))) { 1893 return AppendImpl( 1894 HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/true)); 1895 } 1896 1897 // Modifies the pattern to match only if the instruction is not a constant. 1898 constexpr auto IsNonConstant() const 1899 -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) { 1900 return WithoutOpcode(HloOpcode::kConstant); 1901 } 1902 1903 // Modifies the pattern to match only if the instruction has a shape that 1904 // matches the given pattern. 1905 template <typename ShapeType, typename ShapeImpl> 1906 constexpr auto WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape) 1907 const -> decltype(this->AppendImpl( 1908 HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape))) { 1909 return AppendImpl( 1910 HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape)); 1911 } 1912 1913 // Make this a templated function to work around gcc 4.9.4 template infinite 1914 // recursion bug. 
1915 template <typename Dummy = void> 1916 constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const 1917 -> decltype(this->WithShape(Shape().EqualTo(shape))) { 1918 return WithShape(Shape().EqualTo(shape)); 1919 } 1920 1921 // Make this a templated function to work around gcc 4.9.4 template infinite 1922 // recursion bug. 1923 template <typename Dummy = void> 1924 constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const 1925 -> decltype(this->WithShape(Shape().CompatibleTo(shape))) { 1926 return WithShape(Shape().CompatibleTo(shape)); 1927 } 1928 1929 // Modifies the pattern to match only if the instruction has an operand that 1930 // matches the given pattern. 1931 template <typename OperandType, typename OperandImpl> 1932 constexpr auto WithOperand( 1933 int64 operand_index, 1934 const HloInstructionPattern<OperandType, OperandImpl>& operand) const 1935 -> decltype(this->AppendImpl( 1936 HloInstructionPatternOperandImpl<OperandType, OperandImpl>( 1937 operand_index, operand))) { 1938 return AppendImpl( 1939 HloInstructionPatternOperandImpl<OperandType, OperandImpl>( 1940 operand_index, operand)); 1941 } 1942 1943 template <typename OperandType1, typename OperandImpl1, typename OperandType2, 1944 typename OperandImpl2> 1945 constexpr auto WithBinaryOperandsAnyOrder( 1946 const HloInstructionPattern<OperandType1, OperandImpl1>& op1, 1947 const HloInstructionPattern<OperandType2, OperandImpl2>& op2) const 1948 -> decltype(this->AppendImpl( 1949 HloInstructionPatternBinaryOperandsAnyOrderImpl< 1950 OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, 1951 op2))) { 1952 return AppendImpl( 1953 HloInstructionPatternBinaryOperandsAnyOrderImpl< 1954 OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2)); 1955 } 1956 1957 // Modifies the pattern to match only if the instruction is a fusion node with 1958 // the given kind. 
1959 constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const 1960 -> decltype(this->AppendImpl(HloInstructionPatternFusionKindImpl(kind))) { 1961 return AppendImpl(HloInstructionPatternFusionKindImpl(kind)); 1962 } 1963 1964 // Modifies the pattern to match only if the instruction is a 1965 // get-tuple-element with the given tuple index. 1966 constexpr auto WithTupleIndex(int64 tuple_index) const -> decltype( 1967 this->AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index))) { 1968 return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index)); 1969 } 1970 1971 // Modifies the pattern to match only if the instruction is a parameter 1972 // with the given parameter number. 1973 constexpr auto WithParameterNum(int64 parameter_num) const -> decltype( 1974 this->AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num))) { 1975 return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num)); 1976 } 1977 1978 // Modifies the pattern to match if the instruction is used exactly once. 1979 // Does not match if the instruction is used twice by the same user (e.g. 1980 // multiply(x,x)). 1981 constexpr auto WithOneUse() const 1982 -> decltype(this->AppendImpl(HloInstructionPatternOneUseImpl())) { 1983 return AppendImpl(HloInstructionPatternOneUseImpl()); 1984 } 1985 1986 // Modifies the pattern to match if the instruction is used by exactly one 1987 // other instruction. Will match if the instruction is used twice, so long as 1988 // it's by the same user (e.g. multiply(x,x)). 1989 constexpr auto WithOneUser() const 1990 -> decltype(this->AppendImpl(HloInstructionPatternOneUserImpl())) { 1991 return AppendImpl(HloInstructionPatternOneUserImpl()); 1992 } 1993 1994 // Modifies the pattern to match only if the instruction has the given 1995 // comparison direction. 
1996 auto WithComparisonDirection(ComparisonDirection direction) const 1997 -> decltype(this->AppendImpl( 1998 HloInstructionPatternComparisonDirectionImpl(direction))) { 1999 return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction)); 2000 } 2001 2002 void DescribeTo(std::ostream* os, int64 indent = 0) const { 2003 impl_.DescribeTo(os, indent); 2004 } 2005 2006 private: 2007 Impl impl_; 2008 HloInstructionType** matched_inst_; 2009 }; 2010 2011 } // namespace detail 2012 2013 // Creates an instruction pattern that will capture the matched instruction in 2014 // the argument. 2015 inline constexpr detail::HloInstructionPattern< 2016 const ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl> 2017 Op(const ::xla::HloInstruction** matched_inst = nullptr) { 2018 return detail::HloInstructionPattern<const ::xla::HloInstruction, 2019 detail::HloInstructionPatternBaseImpl>( 2020 detail::HloInstructionPatternBaseImpl(), matched_inst); 2021 } 2022 2023 // Creates an instruction pattern that will capture the matched instruction in 2024 // the argument. 2025 inline constexpr detail::HloInstructionPattern< 2026 ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl> 2027 Op(::xla::HloInstruction** matched_inst) { 2028 return detail::HloInstructionPattern<::xla::HloInstruction, 2029 detail::HloInstructionPatternBaseImpl>( 2030 detail::HloInstructionPatternBaseImpl(), matched_inst); 2031 } 2032 2033 // Helpers for nullary instructions. 
2034 #define XLA_NULLOP_PATTERN(NAME) \ 2035 inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ 2036 return Op().WithOpcode(HloOpcode::k##NAME); \ 2037 } \ 2038 \ 2039 template <typename HloInstructionType> \ 2040 inline auto NAME(HloInstructionType** matched_inst) \ 2041 ->decltype(Op(matched_inst).WithOpcode(HloOpcode::k##NAME)) { \ 2042 return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \ 2043 } 2044 XLA_NULLOP_PATTERN(Constant) 2045 XLA_NULLOP_PATTERN(Parameter) 2046 XLA_NULLOP_PATTERN(Iota) 2047 XLA_NULLOP_PATTERN(Rng) 2048 #undef XLA_NULLOP_PATTERN 2049 2050 // Helpers for unary instructions. 2051 #define XLA_UNOP_PATTERN(NAME) \ 2052 inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ 2053 return Op().WithOpcode(HloOpcode::k##NAME); \ 2054 } \ 2055 \ 2056 template <typename Arg> \ 2057 inline auto NAME(Arg&& arg)->decltype( \ 2058 Op().WithOpcode(HloOpcode::k##NAME) \ 2059 .WithOperand(0, std::forward<Arg>(arg))) { \ 2060 return Op() \ 2061 .WithOpcode(HloOpcode::k##NAME) \ 2062 .WithOperand(0, std::forward<Arg>(arg)); \ 2063 } \ 2064 \ 2065 template <typename HloInstructionType, typename Arg> \ 2066 inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) \ 2067 ->decltype(Op(matched_inst) \ 2068 .WithOpcode(HloOpcode::k##NAME) \ 2069 .WithOperand(0, std::forward<Arg>(arg))) { \ 2070 return Op(matched_inst) \ 2071 .WithOpcode(HloOpcode::k##NAME) \ 2072 .WithOperand(0, std::forward<Arg>(arg)); \ 2073 } 2074 XLA_UNOP_PATTERN(Abs) 2075 XLA_UNOP_PATTERN(RoundNearestAfz) 2076 XLA_UNOP_PATTERN(Bitcast) 2077 XLA_UNOP_PATTERN(Broadcast) 2078 XLA_UNOP_PATTERN(Ceil) 2079 XLA_UNOP_PATTERN(Convert) 2080 XLA_UNOP_PATTERN(Copy) 2081 XLA_UNOP_PATTERN(Cos) 2082 XLA_UNOP_PATTERN(AllReduce) 2083 XLA_UNOP_PATTERN(Exp) 2084 XLA_UNOP_PATTERN(Fft) 2085 XLA_UNOP_PATTERN(Floor) 2086 XLA_UNOP_PATTERN(GetTupleElement) 2087 XLA_UNOP_PATTERN(Imag) 2088 XLA_UNOP_PATTERN(Infeed) 2089 XLA_UNOP_PATTERN(IsFinite) 2090 XLA_UNOP_PATTERN(Log) 
2091 XLA_UNOP_PATTERN(Not) 2092 XLA_UNOP_PATTERN(Negate) 2093 XLA_UNOP_PATTERN(Real) 2094 XLA_UNOP_PATTERN(Recv) 2095 XLA_UNOP_PATTERN(RecvDone) 2096 XLA_UNOP_PATTERN(ReducePrecision) 2097 XLA_UNOP_PATTERN(Reshape) 2098 XLA_UNOP_PATTERN(Reverse) 2099 XLA_UNOP_PATTERN(Rsqrt) 2100 XLA_UNOP_PATTERN(SendDone) 2101 XLA_UNOP_PATTERN(Sign) 2102 XLA_UNOP_PATTERN(Sin) 2103 XLA_UNOP_PATTERN(Slice) 2104 XLA_UNOP_PATTERN(Sqrt) 2105 XLA_UNOP_PATTERN(Tanh) 2106 XLA_UNOP_PATTERN(Transpose) 2107 #undef XLA_UNOP_PATTERN 2108 2109 // Helpers for binary instructions. 2110 #define XLA_BINOP_PATTERN(NAME) \ 2111 inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ 2112 return Op().WithOpcode(HloOpcode::k##NAME); \ 2113 } \ 2114 \ 2115 template <typename Lhs, typename Rhs> \ 2116 inline auto NAME(Lhs&& lhs, Rhs&& rhs) \ 2117 ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \ 2118 .WithOperand(0, std::forward<Lhs>(lhs)) \ 2119 .WithOperand(1, std::forward<Rhs>(rhs))) { \ 2120 return Op() \ 2121 .WithOpcode(HloOpcode::k##NAME) \ 2122 .WithOperand(0, std::forward<Lhs>(lhs)) \ 2123 .WithOperand(1, std::forward<Rhs>(rhs)); \ 2124 } \ 2125 \ 2126 template <typename HloInstructionType, typename Lhs, typename Rhs> \ 2127 inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \ 2128 ->decltype(Op(matched_inst) \ 2129 .WithOpcode(HloOpcode::k##NAME) \ 2130 .WithOperand(0, std::forward<Lhs>(lhs)) \ 2131 .WithOperand(1, std::forward<Rhs>(rhs))) { \ 2132 return Op(matched_inst) \ 2133 .WithOpcode(HloOpcode::k##NAME) \ 2134 .WithOperand(0, std::forward<Lhs>(lhs)) \ 2135 .WithOperand(1, std::forward<Rhs>(rhs)); \ 2136 } 2137 2138 #define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \ 2139 XLA_BINOP_PATTERN(NAME) \ 2140 \ 2141 template <typename HloInstructionType, typename Lhs, typename Rhs> \ 2142 inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ 2143 Rhs&& rhs) \ 2144 ->decltype(Op(matched_inst) \ 2145 .WithOpcode(HloOpcode::k##NAME) \ 2146 
.WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \ 2147 std::forward<Rhs>(rhs))) { \ 2148 return Op(matched_inst) \ 2149 .WithOpcode(HloOpcode::k##NAME) \ 2150 .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \ 2151 std::forward<Rhs>(rhs)); \ 2152 } \ 2153 template <typename Lhs, typename Rhs> \ 2154 inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ 2155 ->decltype(NAME##AnyOrder<const HloInstruction>( \ 2156 nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs))) { \ 2157 return NAME##AnyOrder<const HloInstruction>( \ 2158 nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \ 2159 } 2160 XLA_COMMUTATIVE_BINOP_PATTERN(Add) 2161 XLA_BINOP_PATTERN(Atan2) 2162 XLA_BINOP_PATTERN(Divide) 2163 XLA_BINOP_PATTERN(Complex) 2164 XLA_BINOP_PATTERN(Compare) 2165 XLA_BINOP_PATTERN(Convolution) 2166 XLA_BINOP_PATTERN(Dot) 2167 XLA_BINOP_PATTERN(Gather) 2168 XLA_COMMUTATIVE_BINOP_PATTERN(Maximum) 2169 XLA_COMMUTATIVE_BINOP_PATTERN(Minimum) 2170 XLA_COMMUTATIVE_BINOP_PATTERN(Multiply) 2171 XLA_BINOP_PATTERN(Outfeed) 2172 XLA_BINOP_PATTERN(Pad) 2173 XLA_BINOP_PATTERN(Power) 2174 XLA_BINOP_PATTERN(ReduceWindow) 2175 XLA_BINOP_PATTERN(Remainder) 2176 XLA_BINOP_PATTERN(Send) 2177 XLA_BINOP_PATTERN(Subtract) 2178 XLA_COMMUTATIVE_BINOP_PATTERN(And) 2179 XLA_COMMUTATIVE_BINOP_PATTERN(Or) 2180 XLA_BINOP_PATTERN(ShiftLeft) 2181 XLA_BINOP_PATTERN(ShiftRightArithmetic) 2182 XLA_BINOP_PATTERN(ShiftRightLogical) 2183 #undef XLA_COMMUTATIVE_BINOP_PATTERN 2184 #undef XLA_BINOP_PATTERN 2185 2186 // Helpers for ternary instructions. 
2187 #define XLA_TERNOP_PATTERN(NAME) \ 2188 inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ 2189 return Op().WithOpcode(HloOpcode::k##NAME); \ 2190 } \ 2191 \ 2192 template <typename Arg0, typename Arg1, typename Arg2> \ 2193 inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) \ 2194 ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \ 2195 .WithOperand(0, std::forward<Arg0>(arg0)) \ 2196 .WithOperand(1, std::forward<Arg1>(arg1)) \ 2197 .WithOperand(2, std::forward<Arg2>(arg2))) { \ 2198 return Op() \ 2199 .WithOpcode(HloOpcode::k##NAME) \ 2200 .WithOperand(0, std::forward<Arg0>(arg0)) \ 2201 .WithOperand(1, std::forward<Arg1>(arg1)) \ 2202 .WithOperand(2, std::forward<Arg2>(arg2)); \ 2203 } \ 2204 \ 2205 template <typename HloInstructionType, typename Arg0, typename Arg1, \ 2206 typename Arg2> \ 2207 inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0, \ 2208 Arg1&& arg1, Arg2&& arg2) \ 2209 ->decltype(Op(matched_inst) \ 2210 .WithOpcode(HloOpcode::k##NAME) \ 2211 .WithOperand(0, std::forward<Arg0>(arg0)) \ 2212 .WithOperand(1, std::forward<Arg1>(arg1)) \ 2213 .WithOperand(2, std::forward<Arg2>(arg2))) { \ 2214 return Op(matched_inst) \ 2215 .WithOpcode(HloOpcode::k##NAME) \ 2216 .WithOperand(0, std::forward<Arg0>(arg0)) \ 2217 .WithOperand(1, std::forward<Arg1>(arg1)) \ 2218 .WithOperand(2, std::forward<Arg2>(arg2)); \ 2219 } 2220 XLA_TERNOP_PATTERN(Clamp); 2221 XLA_TERNOP_PATTERN(Scatter); 2222 XLA_TERNOP_PATTERN(Select); 2223 #undef XLA_TERNOP_PATTERN 2224 2225 namespace detail { 2226 template <typename Matcher, typename FirstArg> 2227 inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg) 2228 -> decltype(m.WithOperand(operand_num, std::forward<FirstArg>(first_arg))) { 2229 return m.WithOperand(operand_num, std::forward<FirstArg>(first_arg)); 2230 } 2231 2232 template <typename Matcher, typename FirstArg, typename... 
Args> 2233 inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg, 2234 Args&&... args) 2235 -> decltype(WithOperands(m.WithOperand(operand_num, 2236 std::forward<FirstArg>(first_arg)), 2237 operand_num + 1, std::forward<Args>(args)...)) { 2238 return WithOperands( 2239 m.WithOperand(operand_num, std::forward<FirstArg>(first_arg)), 2240 operand_num + 1, std::forward<Args>(args)...); 2241 } 2242 } // namespace detail 2243 2244 #define XLA_VARIADIC_OP_PATTERN(NAME) \ 2245 inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ 2246 return Op().WithOpcode(HloOpcode::k##NAME); \ 2247 } \ 2248 \ 2249 template <typename... Args> \ 2250 inline auto NAME(Args&&... args) \ 2251 ->decltype(detail::WithOperands(Op().WithOpcode(HloOpcode::k##NAME) \ 2252 .WithNumOperands(sizeof...(Args)), \ 2253 0, std::forward<Args>(args)...)) { \ 2254 return detail::WithOperands( \ 2255 Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \ 2256 /*operand_num=*/0, std::forward<Args>(args)...); \ 2257 } \ 2258 \ 2259 template <typename HloInstructionType, typename... Args> \ 2260 inline auto NAME(HloInstructionType** matched_inst, Args&&... args) \ 2261 ->decltype(detail::WithOperands(Op(matched_inst) \ 2262 .WithOpcode(HloOpcode::k##NAME) \ 2263 .WithNumOperands(sizeof...(Args)), \ 2264 0, std::forward<Args>(args)...)) { \ 2265 return detail::WithOperands(Op(matched_inst) \ 2266 .WithOpcode(HloOpcode::k##NAME) \ 2267 .WithNumOperands(sizeof...(Args)), \ 2268 /*operand_num=*/0, \ 2269 std::forward<Args>(args)...); \ 2270 } 2271 2272 // We could implement all ops as "variadic" ops, but it would make the 2273 // already-bad compile errors even worse. 
2274 XLA_VARIADIC_OP_PATTERN(AfterAll); 2275 XLA_VARIADIC_OP_PATTERN(Concatenate); 2276 XLA_VARIADIC_OP_PATTERN(CustomCall); 2277 XLA_VARIADIC_OP_PATTERN(DynamicSlice) 2278 XLA_VARIADIC_OP_PATTERN(Map) 2279 XLA_VARIADIC_OP_PATTERN(Reduce); 2280 XLA_VARIADIC_OP_PATTERN(Sort); 2281 XLA_VARIADIC_OP_PATTERN(Tuple); 2282 2283 // Helpers for comparison instructions. 2284 #define XLA_COMPARE_PATTERN(NAME) \ 2285 inline auto NAME()->decltype( \ 2286 Op().WithOpcode(HloOpcode::kCompare) \ 2287 .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ 2288 return Op() \ 2289 .WithOpcode(HloOpcode::kCompare) \ 2290 .WithComparisonDirection(ComparisonDirection::k##NAME); \ 2291 } \ 2292 \ 2293 template <typename Lhs, typename Rhs> \ 2294 inline auto NAME(Lhs&& lhs, Rhs&& rhs) \ 2295 ->decltype(Op().WithOpcode(HloOpcode::kCompare) \ 2296 .WithOperand(0, std::forward<Lhs>(lhs)) \ 2297 .WithOperand(1, std::forward<Rhs>(rhs)) \ 2298 .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ 2299 return Op() \ 2300 .WithOpcode(HloOpcode::kCompare) \ 2301 .WithOperand(0, std::forward<Lhs>(lhs)) \ 2302 .WithOperand(1, std::forward<Rhs>(rhs)) \ 2303 .WithComparisonDirection(ComparisonDirection::k##NAME); \ 2304 } \ 2305 \ 2306 template <typename HloInstructionType, typename Lhs, typename Rhs> \ 2307 inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \ 2308 ->decltype(Op(matched_inst) \ 2309 .WithOpcode(HloOpcode::kCompare) \ 2310 .WithOperand(0, std::forward<Lhs>(lhs)) \ 2311 .WithOperand(1, std::forward<Rhs>(rhs)) \ 2312 .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ 2313 return Op(matched_inst) \ 2314 .WithOpcode(HloOpcode::kCompare) \ 2315 .WithOperand(0, std::forward<Lhs>(lhs)) \ 2316 .WithOperand(1, std::forward<Rhs>(rhs)) \ 2317 .WithComparisonDirection(ComparisonDirection::k##NAME); \ 2318 } 2319 2320 #define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \ 2321 XLA_COMPARE_PATTERN(NAME) \ 2322 \ 2323 template <typename HloInstructionType, typename 
Lhs, typename Rhs> \ 2324 inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ 2325 Rhs&& rhs) \ 2326 ->decltype(Op(matched_inst) \ 2327 .WithOpcode(HloOpcode::kCompare) \ 2328 .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \ 2329 std::forward<Rhs>(rhs))) { \ 2330 return Op(matched_inst) \ 2331 .WithOpcode(HloOpcode::kCompare) \ 2332 .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \ 2333 std::forward<Rhs>(rhs)); \ 2334 } \ 2335 template <typename Lhs, typename Rhs> \ 2336 inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ 2337 ->decltype(NAME##AnyOrder<const HloInstruction>( \ 2338 nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs))) { \ 2339 return NAME##AnyOrder<const HloInstruction>( \ 2340 nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \ 2341 } 2342 2343 XLA_COMMUTATIVE_COMPARE_PATTERN(Eq); 2344 XLA_COMMUTATIVE_COMPARE_PATTERN(Ne); 2345 XLA_COMPARE_PATTERN(Ge); 2346 XLA_COMPARE_PATTERN(Gt); 2347 XLA_COMPARE_PATTERN(Le); 2348 XLA_COMPARE_PATTERN(Lt); 2349 2350 // Helpers for matching non-constant instructions. 2351 inline auto NonConstant() -> decltype(Op().IsNonConstant()) { 2352 return Op().IsNonConstant(); 2353 } 2354 2355 template <typename HloInstructionType> 2356 inline auto NonConstant(HloInstructionType** matched_inst) 2357 -> decltype(Op(matched_inst).IsNonConstant()) { 2358 return Op(matched_inst).IsNonConstant(); 2359 } 2360 2361 // Add overloads for GetTupleElement which take a int64 specifying which tuple 2362 // element is selected. 
2363 template <typename Arg> 2364 inline auto GetTupleElement(Arg&& arg, int64 tuple_index) 2365 -> decltype(Op().WithOpcode(HloOpcode::kGetTupleElement) 2366 .WithOperand(0, std::forward<Arg>(arg)) 2367 .WithTupleIndex(tuple_index)) { 2368 return Op() 2369 .WithOpcode(HloOpcode::kGetTupleElement) 2370 .WithOperand(0, std::forward<Arg>(arg)) 2371 .WithTupleIndex(tuple_index); 2372 } 2373 2374 template <typename HloInstructionType, typename Arg> 2375 inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, 2376 int64 tuple_index) 2377 -> decltype(Op(matched_inst) 2378 .WithOpcode(HloOpcode::kGetTupleElement) 2379 .WithOperand(0, std::forward<Arg>(arg)) 2380 .WithTupleIndex(tuple_index)) { 2381 return Op(matched_inst) 2382 .WithOpcode(HloOpcode::kGetTupleElement) 2383 .WithOperand(0, std::forward<Arg>(arg)) 2384 .WithTupleIndex(tuple_index); 2385 } 2386 2387 // Add overloads for Parameter which take an int64 specifying the parameter 2388 // number. 2389 inline auto Parameter(int64 parameter_num) -> decltype( 2390 Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num)) { 2391 return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num); 2392 } 2393 template <typename HloInstructionType> 2394 inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num) 2395 -> decltype(Op(matched_inst) 2396 .WithOpcode(HloOpcode::kParameter) 2397 .WithParameterNum(parameter_num)) { 2398 return Op(matched_inst) 2399 .WithOpcode(HloOpcode::kParameter) 2400 .WithParameterNum(parameter_num); 2401 } 2402 2403 inline auto ConstantScalar() -> decltype(Op().IsConstantScalar()) { 2404 return Op().IsConstantScalar(); 2405 } 2406 2407 template <typename HloInstructionType> 2408 inline auto ConstantScalar(HloInstructionType** matched_inst) 2409 -> decltype(Op(matched_inst).IsConstantScalar()) { 2410 return Op(matched_inst).IsConstantScalar(); 2411 } 2412 2413 template <typename ScalarTy> 2414 inline auto ConstantScalar(ScalarTy 
val) 2415 -> decltype(Op().IsConstantScalar(val)) { 2416 return Op().IsConstantScalar(val); 2417 } 2418 2419 template <typename HloInstructionType, typename ScalarTy> 2420 inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) 2421 -> decltype(Op(matched_inst).IsConstantScalar(val)) { 2422 return Op(matched_inst).IsConstantScalar(val); 2423 } 2424 2425 inline auto ConstantEffectiveScalar() -> decltype(Op().IsConstantScalar()) { 2426 return Op().IsConstantEffectiveScalar(); 2427 } 2428 2429 template <typename HloInstructionType> 2430 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) 2431 -> decltype(Op(matched_inst).IsConstantScalar()) { 2432 return Op(matched_inst).IsConstantEffectiveScalar(); 2433 } 2434 2435 template <typename ScalarTy> 2436 inline auto ConstantEffectiveScalar(ScalarTy val) 2437 -> decltype(Op().IsConstantEffectiveScalar(val)) { 2438 return Op().IsConstantEffectiveScalar(val); 2439 } 2440 2441 template <typename HloInstructionType, typename ScalarTy> 2442 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst, 2443 ScalarTy val) 2444 -> decltype(Op(matched_inst).IsConstantEffectiveScalar(val)) { 2445 return Op(matched_inst).IsConstantEffectiveScalar(val); 2446 } 2447 2448 } // namespace match 2449 2450 } // namespace xla 2451 2452 #undef EXPLAIN 2453 #pragma pop_macro("EXPLAIN") 2454 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ 2455