Home | History | Annotate | Download | only in service
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
     18 
#include <ostream>
#include <sstream>
#include <tuple>
#include <type_traits>
#include <utility>

#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/utility/utility.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
     29 
     30 namespace xla {
     31 
     32 // A pattern matcher for HloInstructions, Shapes, and Layouts.
     33 //
     34 // The Match function's first argument must be HloInstruction*, Shape*, or
     35 // Layout*. The second argument is a pattern that will be matched against the
     36 // first argument, as described below.
     37 //
     38 // Patterns are constructed using the match::Op, match::Shape, or match::Layout
     39 // functions. By default, the returned patterns will match any HloInstruction,
     40 // Shape, or Layout, respectively. However the match can be made more specific
     41 // by using the pattern's modifier methods, for example:
     42 //
     43 //   match::Op().WithOpcode(HloOpcode::kAdd).WithOperand(
     44 //     0, match::Op().WithOpcode(HloOpcode::kConstant))
     45 //
     46 // This pattern will match Add instructions whose first operand is a constant.
     47 //
     48 // Each pattern type has the following modifiers, which are described where
     49 // nontrivial.
     50 //
     51 //   Op():
     52 //     - Is: is the given HloInstruction* (i.e. pointer equality)
     53 //     - WithName
     54 //     - WithOpcode
     55 //     - WithoutOpcode: anything other than the given opcode
     56 //     - WithShape: instr's shape matches the given pattern
     57 //     - WithShapeEqualTo: instr's shape is equal to the given Shape
     58 //     - WithShapeCompatibleTo: instr's shape is compatible with the given Shape
     59 //     - WithNumOperands
     60 //     - WithOperand: operand at the given index matches the given pattern
     61 //     - IsConstant
     62 //     - IsNonConstant
     63 //     - IsConstantScalar/IsConstantEffectiveScalar: Optionally accepts a value,
     64 //       e.g. IsConstantScalar() or IsConstantScalar(42).
     65 //     - WithFusionKind
     66 //     - WithTupleIndex: get-tuple-element operations with the given tuple index
     67 //     - WithOneUse: Instruction is used as an operand exactly once.
     68 //     - WithOneUser: Instruction is used by exactly one other instruction, but
     69 //       is possibly used more than once as an operand (e.g. multiply(x,x)).
     70 //     - WithComparisonDirection: instr has the given direction
     71 //
     72 //   Shape():
     73 //     - EqualTo
     74 //     - CompatibleTo
     75 //     - IsScalar/IsEffectiveScalar/IsArray/IsTuple
     76 //     - IsDenseArray/IsSparseArray
     77 //     - WithLayout: shape's layout matches the given pattern (e.g.
     78 //       Layout().WithDenseFormat())
     79 //     - WithLayoutEqualTo: shape's layout equals the argument (i.e. another
     80 //       Layout, but not the result of Layout().foo())
     81 //     - WithSubshape: shape is a tuple whose subshape matches the given pattern
     82 //       (e.g. Shape().IsScalar()).
     83 //     - WithSubshapeEqualTo: shape is a tuple with a subshape equal to the arg
     84 //       (i.e. another Shape, but not the result of Shape().foo())
     85 //     - WithElementType: shape is an array/scalar with the given elem type
     86 //     - WithRank: shape is an array/scalar with the given rank
     87 //
     88 //  Layout():
     89 //     - EqualTo
     90 //     - WithDenseFormat/WithSparseFormat
     91 //
     92 // Op(), Shape(), and Layout() may be passed an argument of type
     93 // HloInstruction**, Shape**, or Layout**, respectively, or const versions of
     94 // these pointers. If the pattern is matched, the address of the matched value
     95 // will be "captured" and stored at this location.
     96 //
     97 // For example:
     98 //   HloInstruction* foo = ...;
     99 //   HloInstruction* matched_operand;
    100 //   CHECK(Match(foo,
    101 //               match::Op().WithOperand(0, match::Op(&matched_operand))));
    102 //
    103 // Helpers are provided for most HLO instructions. These helpers can be called
    104 // with no arguments, in which case they will match any instruction matching the
    105 // opcode. They may also be called with matches for the operands and with an
    106 // optional capture. (The capture must be the first argument.) Some examples of
    107 // these helpers and their equivalents are provided below.
    108 
    109 // Example nullary instruction:
    110 //   Parameter()                    == Op().WithOpcode(HloOpcode::kParameter)
    111 //   Parameter(&a)                  == Op(&a).WithOpcode(HloOpcode::kParameter)
    112 //
    113 // Example unary instruction:
    114 //   Abs()                          == Op().WithOpcode(HloOpcode::kAbs)
    115 //   Abs(Op(&a))                    == Op().WithOpcode(HloOpcode::kAbs)
    116 //                                         .WithOperand(0, Op(&a))
    117 //   Abs(&a, Op(&b))                == Op(&a).WithOpcode(HloOpcode::kAbs)
    118 //                                           .WithOperand(0, Op(&b))
    119 //
    120 // Commutative binary instructions have a special form that accepts either order
    121 // of args, e.g.:
    122 //
    123 //   AddAnyOrder(Parameter(1), Abs()) ==
    124 //     Op().WithOpcode(HloOpcode::kAdd)
    125 //         .WithBinaryOperandsAnyOrder(Op().WithParameterNum(1), Abs());
    126 //
    127 //   MultiplyAnyOrder(&a, Parameter(), Abs())  // Captures the mul in `a`.
    128 //
    129 // The following additional helpers are provided.  In all cases, `&a` is
    130 // optional.
    131 //
    132 //   ConstantScalar(&a)               == Op(&a).IsConstantScalar();
    133 //   ConstantScalar(&a, v)            == Op(&a).IsConstantScalar(v);
    134 //   ConstantEffectiveScalar(&a)      == Op(&a).IsConstantEffectiveScalar();
    135 //   ConstantEffectiveScalar(&a, v)   == Op(&a).IsConstantEffectiveScalar(v);
    136 //   NonConstant(&a)                  == Op(&a).IsNonConstant()
    137 //   GetTupleElement(&a, b, index)    == Op(&a).WithTupleIndex(index)
    138 //                                             .WithOperand(0, b);
    139 //   Parameter(&a, n)                 == Op(&a).WithParameterNum(n);
    140 
    141 struct MatchOption {
    142   // If true, actually capture matched item into the user pointer.
    143   bool capture;
    144 
    145   // An explanation for why we failed to match is streamed here, if not-null.
    146   std::ostream* explain_os;
    147 };
    148 
    149 template <typename Value, typename Pattern>
    150 bool Match(Value* value, const Pattern& pattern,
    151            MatchOption option = {/*.capture=*/true, /*.explain_os=*/nullptr}) {
    152   if (option.capture) {
    153     auto new_option = option;
    154     new_option.capture = false;
    155     if (!pattern.Match(value, new_option)) {
    156       return false;
    157     }
    158   }
    159   return pattern.Match(value, option);
    160 }
    161 
    162 namespace match {
    163 
    164 namespace detail {
    165 
    166 // Macro for streaming to option.explain_os if it's not null.
    167 //
    168 //   EXPLAIN << "value of foo(): " << foo()
    169 //
    170 #pragma push_macro("EXPLAIN")
    171 #define EXPLAIN \
    172   if (option.explain_os) *option.explain_os
    173 
    174 // kIndentInc is the additional number of spaces that we indent by when we
    175 // increase the indent "by one".
    176 enum {
    177   kIndentInc = 2,
    178 };
    179 
    180 // Writes a newline and then `indent` spaces.
    181 //
    182 // We follow an unintuitive convention in this file's pretty-printers: Indents
    183 // are performed by the caller, not the callee.  For example, if you want to
    184 // print
    185 //
    186 //   foo:
    187 //    - bar
    188 //
    189 // you'd do:
    190 //
    191 //  Foo::DescribeTo(std::ostream* os, int64 indent) {
    192 //    *os << "foo:";
    193 //    Indent(os, indent);  // Create a newline at the *current* indent level.
    194 //    *os << " - ";
    195 //    bar.DescribeTo(os, indent + 3);  // + 3 because strlen(" - ") == 3.
    196 //  }
    197 //
    198 //  Bar::DescribeTo(std::ostream* os, int64 indent) { *os << "bar"; }
    199 //
    200 // Notice that Bar::DescribeTo() does not call Indent; the indenting is
    201 // performed by Foo.  This convention allows the caller to decide whether a
    202 // matcher is preceded by a newline, which is important e.g. for the AllOf
    203 // matcher.
    204 //
    205 // (Incidentally, indenting in Match's explanations is handled differently.
    206 // Indents are a common case in DescribeTo [we're printing a whole tree], but
    207 // they're a special case in Match [we're printing only a path through the tree
    208 // that encounters a failing node]. Indents in Match only appear when we
    209 // encounter a failing disjunction, so we just handle them as a special case
    210 // there.)
    211 inline void Indent(std::ostream* os, int64 indent) {
    212   *os << "\n";
    213   for (int64 i = 0; i < indent; ++i) {
    214     *os << " ";
    215   }
    216 }
    217 
    218 // SFINAE template that determines whether T declares a static member
    219 // kIsTrivialMatcher.
    220 //
    221 // Trivial matchers get special treatment.  For example, when printing
    222 // a conjunction of matchers, we don't print "and" after a trivial matcher. This
    223 // yields e.g.
    224 //    "a shape compatible with f32[1,2]"
    225 // rather than
    226 //    "a shape AND compatible with f32[1,2]"
    227 template <typename T, typename Dummy = void>
    228 struct IsTrivialMatcher {
    229   static constexpr bool value = false;
    230 };
    231 template <typename T>
    232 struct IsTrivialMatcher<T,
    233                         typename std::enable_if<T::kIsTrivialMatcher>::type> {
    234   static constexpr bool value = true;
    235 };
    236 
    237 template <typename Item, typename... Patterns>
    238 class AllOfPattern {
    239  public:
    240   explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
    241 
    242   bool Match(const Item* item, MatchOption option) const {
    243     bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
    244     // This invariant is guaranteed by the top-level Match and AnyOf.
    245     DCHECK(matched || !option.capture);
    246     return matched;
    247   }
    248 
    249   bool Match(Item* item, MatchOption option) const {
    250     bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
    251     // This invariant is guaranteed by the top-level Match and AnyOf.
    252     DCHECK(matched || !option.capture);
    253     return matched;
    254   }
    255 
    256   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    257     DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
    258   }
    259 
    260   // Accessor for patterns_.  Please don't use this outside of this file.
    261   const std::tuple<Patterns...>& patterns() const { return patterns_; }
    262 
    263  private:
    264   template <typename ItemType, size_t index>
    265   bool MatchImpl(ItemType* item, MatchOption option,
    266                  std::integral_constant<size_t, index>) const {
    267     // We don't need to do any EXPLAINing here; it's all correctly handled by
    268     // our sub-matchers (if any fail).
    269     return std::get<index>(patterns_).Match(item, option) &&
    270            MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
    271   }
    272 
    273   template <typename ItemType>
    274   bool MatchImpl(ItemType* item, MatchOption option,
    275                  std::integral_constant<size_t, sizeof...(Patterns)>) const {
    276     return true;
    277   }
    278 
    279   // Pretty-printing a conjunction has some special cases to make it easy to
    280   // read in the simple (common) case.
    281   //
    282   // If sizeof...(Patterns) == 1, prints as e.g.
    283   //
    284   //   a shape
    285   //
    286   // If sizeof...(Patterns) == 2 and patterns_[0] is a trivial matcher (e.g. "a
    287   // shape") prints as
    288   //
    289   //   a shape compatible with f32[1,2]
    290   //
    291   // If sizeof...(Patterns) > 2 and patterns_[0] is a trivial matcher, prints as
    292   //
    293   //   a shape:
    294   //    * compatible with f32[1,2] AND
    295   //    * that represents a scalar
    296   //
    297   // Otherwise prints as:
    298   //
    299   //   all of:
    300   //    * foo AND
    301   //    * bar
    302   //
    303   template <size_t index>
    304   void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
    305                       int64 indent) const {
    306     constexpr bool first_is_trivial =
    307         IsTrivialMatcher<typename std::remove_reference<decltype(
    308             std::get<0>(patterns_))>::type>::value;
    309     constexpr bool is_last = index == sizeof...(Patterns) - 1;
    310     const auto& submatcher = std::get<index>(patterns_);
    311 
    312     auto print_bulleted_item = [&] {
    313       *os << " * ";
    314       submatcher.DescribeTo(os, indent + 3);
    315       if (!is_last) {
    316         *os << " AND";
    317         Indent(os, indent);
    318       }
    319     };
    320 
    321     if (index == 0) {
    322       if (first_is_trivial || is_last) {
    323         submatcher.DescribeTo(os, indent + kIndentInc);
    324         if (sizeof...(Patterns) > 2) {
    325           *os << ":";
    326           Indent(os, indent);
    327         }
    328       } else {
    329         *os << "all of:";
    330         Indent(os, indent);
    331         print_bulleted_item();
    332       }
    333     } else if (first_is_trivial && index == 1 && sizeof...(Patterns) == 2) {
    334       *os << " ";
    335       submatcher.DescribeTo(os, indent);
    336     } else {
    337       print_bulleted_item();
    338     }
    339     DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
    340   }
    341 
    342   void DescribeToImpl(std::ostream* os,
    343                       std::integral_constant<size_t, sizeof...(Patterns)>,
    344                       int64 indent) const {}
    345 
    346   std::tuple<Patterns...> patterns_;
    347 };
    348 
    349 }  // namespace detail
    350 
    351 // Returns a pattern that represents the conjunction of all input patterns. All
    352 // patterns need to match in order to have the AllOf pattern match.
    353 template <typename Item, typename... Patterns>
    354 detail::AllOfPattern<typename std::remove_const<Item>::type, Patterns...> AllOf(
    355     const Patterns&... patterns) {
    356   return detail::AllOfPattern<typename std::remove_const<Item>::type,
    357                               Patterns...>(patterns...);
    358 }
    359 
    360 // AllOf<AllOf<A, B...>, X, Y, ...> => AllOf<A, B, ..., X, Y, ...>.
    361 //
    362 // This transformation is necessary for good pretty-printing.
    363 template <typename Item, typename... InnerPs, typename... OuterPs>
    364 detail::AllOfPattern<typename std::remove_const<Item>::type, InnerPs...,
    365                      OuterPs...>
    366 AllOf(const detail::AllOfPattern<Item, InnerPs...>& inner_p,
    367       const OuterPs&... outer_ps) {
    368   // Invoke constructor of AllOfPattern<Item, InnerPs..., OuterPs...>.
    369   auto make_all_of = [](const InnerPs&... inner_ps,
    370                         const OuterPs&... outer_ps) {
    371     return detail::AllOfPattern<typename std::remove_const<Item>::type,
    372                                 InnerPs..., OuterPs...>(inner_ps...,
    373                                                         outer_ps...);
    374   };
    375   return absl::apply(make_all_of, std::tuple_cat(inner_p.patterns(),
    376                                                  std::make_tuple(outer_ps...)));
    377 }
    378 
    379 namespace detail {
    380 
    381 template <typename LayoutType, typename Impl>
    382 class LayoutPattern;
    383 
    384 // The base LayoutPattern implementation. Matches only if the layout is not
    385 // nullptr.
    386 class LayoutPatternBaseImpl {
    387  public:
    388   bool Match(const ::xla::Layout* layout, MatchOption option) const {
    389     if (layout == nullptr) {
    390       EXPLAIN << "Layout is null";
    391       return false;
    392     }
    393     return true;
    394   }
    395 
    396   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    397     *os << "a layout";
    398   }
    399 
    400   static constexpr bool kIsTrivialMatcher = true;
    401 };
    402 
    403 // A LayoutPattern implementation that matches only if the layout equals a
    404 // Layout proto.
    405 class LayoutPatternEqualImpl {
    406  public:
    407   explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout)
    408       : layout_(layout) {}
    409 
    410   bool Match(const ::xla::Layout* layout, MatchOption option) const {
    411     if (!LayoutUtil::Equal(*layout_, *layout)) {
    412       EXPLAIN << "Layout " << LayoutUtil::HumanString(*layout)
    413               << " is not equal to expected "
    414               << LayoutUtil::HumanString(*layout_);
    415       return false;
    416     }
    417     return true;
    418   }
    419 
    420   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    421     *os << "equal to " << LayoutUtil::HumanString(*layout_);
    422   }
    423 
    424  private:
    425   const ::xla::Layout* layout_;
    426 };
    427 
    428 // A LayoutPattern implementation that matches only if the layout has a given
    429 // format.
    430 class LayoutPatternFormatImpl {
    431  public:
    432   explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {}
    433 
    434   bool Match(const ::xla::Layout* layout, MatchOption option) const {
    435     if (layout->format() != format_) {
    436       EXPLAIN << "Layout has format " << Format_Name(layout->format())
    437               << " but expected " << Format_Name(format_);
    438       return false;
    439     }
    440     return true;
    441   }
    442 
    443   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    444     *os << "with format " << Format_Name(format_);
    445   }
    446 
    447  private:
    448   Format format_;
    449 };
    450 
    451 // A pattern that matches Layouts.
    452 template <typename LayoutType, typename Impl>
    453 class LayoutPattern {
    454  private:
    455   template <typename NewImpl>
    456   auto AppendImpl(NewImpl new_impl) const
    457       -> LayoutPattern<LayoutType,
    458                        decltype(AllOf<Layout>(std::declval<Impl>(),
    459                                               std::move(new_impl)))> {
    460     auto new_allof = AllOf<Layout>(impl_, std::move(new_impl));
    461     return LayoutPattern<LayoutType, decltype(new_allof)>(std::move(new_allof),
    462                                                           matched_layout_);
    463   }
    464 
    465  public:
    466   explicit constexpr LayoutPattern(const Impl& impl,
    467                                    LayoutType** matched_layout)
    468       : impl_(impl), matched_layout_(matched_layout) {}
    469 
    470   // Returns true and captures the layout iff it matches the pattern.
    471   bool Match(const ::xla::Layout* layout, MatchOption option) const {
    472     if (impl_.Match(layout, option)) {
    473       if (option.capture && matched_layout_) {
    474         *matched_layout_ = layout;
    475       }
    476       return true;
    477     }
    478     return false;
    479   }
    480 
    481   // Returns true and captures the layout iff it matches the pattern.
    482   bool Match(::xla::Layout* layout, MatchOption option) const {
    483     if (impl_.Match(layout, option)) {
    484       if (option.capture && matched_layout_) {
    485         *matched_layout_ = layout;
    486       }
    487       return true;
    488     }
    489     return false;
    490   }
    491 
    492   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    493     impl_.DescribeTo(os, indent);
    494   }
    495 
    496   // Modifies the pattern to match only if the layout equals the given proto.
    497   // The layout must outlive the returned pattern.
    498   constexpr auto EqualTo(const ::xla::Layout* layout) const
    499       -> decltype(this->AppendImpl(LayoutPatternEqualImpl(layout))) {
    500     return AppendImpl(LayoutPatternEqualImpl(layout));
    501   }
    502 
    503   // Modifies the pattern to match only if the layout has a dense format.
    504   constexpr auto WithDenseFormat() const
    505       -> decltype(this->AppendImpl(LayoutPatternFormatImpl(DENSE))) {
    506     return AppendImpl(LayoutPatternFormatImpl(DENSE));
    507   }
    508 
    509   // Modifies the pattern to match only if the layout has a sparse format.
    510   constexpr auto WithSparseFormat() const
    511       -> decltype(this->AppendImpl(LayoutPatternFormatImpl(SPARSE))) {
    512     return AppendImpl(LayoutPatternFormatImpl(SPARSE));
    513   }
    514 
    515  private:
    516   Impl impl_;
    517   LayoutType** matched_layout_;
    518 };
    519 
    520 template <typename Item, typename... Patterns>
    521 class AnyOfPattern {
    522  public:
    523   explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
    524 
    525   bool Match(const Item* item, MatchOption option) const {
    526     return MatchImpl(item, option);
    527   }
    528 
    529   bool Match(Item* item, MatchOption option) const {
    530     return MatchImpl(item, option);
    531   }
    532 
    533   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    534     *os << "any of:";
    535     Indent(os, indent);
    536     DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
    537   }
    538 
    539  private:
    540   template <typename ItemType>
    541   bool MatchImpl(ItemType* item, MatchOption option) const {
    542     // If we're generating an explanation, buffer it until we know we failed.
    543     absl::optional<std::stringstream> explanation;
    544     MatchOption new_option = option;
    545     if (option.explain_os) {
    546       new_option.explain_os = &explanation.emplace();
    547     }
    548     bool rv = MatchRecursiveImpl(item, new_option,
    549                                  std::integral_constant<size_t, 0>());
    550     if (!rv && option.explain_os) {
    551       EXPLAIN << "None of the following matchers succeeded:";
    552       EXPLAIN << explanation->str();
    553     }
    554     return rv;
    555   }
    556 
    557   template <typename ItemType, size_t index>
    558   bool MatchRecursiveImpl(ItemType* item, MatchOption option,
    559                           std::integral_constant<size_t, index>) const {
    560     auto new_option = option;
    561     new_option.capture = false;
    562 
    563     absl::optional<std::stringstream> explanation;
    564     if (option.explain_os) {
    565       new_option.explain_os = &explanation.emplace();
    566     }
    567 
    568     // Try to match the sub-pattern without capturing behavior.
    569     if (std::get<index>(patterns_).Match(item, new_option)) {
    570       // Capture the branch.
    571       if (option.capture) {
    572         // TODO(timshen): Currently the behavior can be exponential. Optimize it
    573         // with memoization or recording the matched sub-pattern index, if it
    574         // takes too long to run.
    575         //
    576         // Specifically, the "memoization" approach is to create an empty
    577         // container with the key (pattern, instruction), and value as whether
    578         // matched or not.
    579         //
    580         // Alternatively, we may run the pattern matching with captures off, but
    581         // instead record a "trace" somewhere, indicating how exactly the
    582         // pattern matches the input. For example, the trace information for
    583         // AnyOf will be a runtime number indicate which sub-pattern is matched.
    584         // Then we run another pass to do captures only with the help of the
    585         // trace.
    586         bool matched = std::get<index>(patterns_).Match(item, option);
    587         DCHECK(matched);
    588       }
    589       return true;
    590     }
    591     if (option.explain_os) {
    592       EXPLAIN << "\nMatcher #" << index + 1;
    593       EXPLAIN << "\n - ";
    594       std::get<index>(patterns_).DescribeTo(option.explain_os, /*indent=*/3);
    595       EXPLAIN << "\nfailed with";
    596       EXPLAIN << "\n - ";
    597       EXPLAIN << absl::StrReplaceAll(explanation->str(), {{"\n", "\n   "}});
    598     }
    599     return MatchRecursiveImpl(item, option,
    600                               std::integral_constant<size_t, index + 1>());
    601   }
    602 
    603   template <typename ItemType>
    604   bool MatchRecursiveImpl(
    605       ItemType* item, MatchOption option,
    606       std::integral_constant<size_t, sizeof...(Patterns)>) const {
    607     return false;
    608   }
    609 
    610   template <size_t index>
    611   void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
    612                       int64 indent) const {
    613     *os << " - ";
    614     std::get<index>(patterns_).DescribeTo(os, indent + 3);
    615     if (index != sizeof...(Patterns) - 1) {
    616       *os << " OR";
    617       Indent(os, indent);
    618     }
    619     DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
    620   }
    621 
    622   void DescribeToImpl(std::ostream* os,
    623                       std::integral_constant<size_t, sizeof...(Patterns)>,
    624                       int64 indent) const {}
    625 
    626   std::tuple<Patterns...> patterns_;
    627 };
    628 
    629 }  // namespace detail
    630 
    631 // Returns a pattern that represents the logical disjunction of the input
    632 // patterns. The returned pattern matches from left to right, and stops on the
    633 // first match.
    634 template <typename Item, typename... Patterns>
    635 detail::AnyOfPattern<typename std::remove_const<Item>::type, Patterns...> AnyOf(
    636     const Patterns&... patterns) {
    637   return detail::AnyOfPattern<typename std::remove_const<Item>::type,
    638                               Patterns...>(patterns...);
    639 }
    640 
    641 // Creates a layout pattern that will capture the matched layout in the
    642 // argument.
    643 inline constexpr detail::LayoutPattern<const ::xla::Layout,
    644                                        detail::LayoutPatternBaseImpl>
    645 Layout(const ::xla::Layout** matched_layout = nullptr) {
    646   return detail::LayoutPattern<const ::xla::Layout,
    647                                detail::LayoutPatternBaseImpl>(
    648       detail::LayoutPatternBaseImpl(), matched_layout);
    649 }
    650 
    651 // Creates a layout pattern that will capture the matched layout in the
    652 // argument.
    653 inline constexpr detail::LayoutPattern<::xla::Layout,
    654                                        detail::LayoutPatternBaseImpl>
    655 Layout(::xla::Layout** matched_layout) {
    656   return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>(
    657       detail::LayoutPatternBaseImpl(), matched_layout);
    658 }
    659 
    660 namespace detail {
    661 
    662 template <typename ShapeType, typename Impl>
    663 class ShapePattern;
    664 
    665 // The base ShapePattern implementation. Matches only if the shape is not
    666 // nullptr.
    667 class ShapePatternBaseImpl {
    668  public:
    669   bool Match(const ::xla::Shape* shape, MatchOption option) const {
    670     if (shape == nullptr) {
    671       EXPLAIN << "Shape is null";
    672     }
    673     return shape != nullptr;
    674   }
    675 
    676   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    677     *os << "a shape";
    678   }
    679 
    680   static constexpr bool kIsTrivialMatcher = true;
    681 };
    682 
    683 // A ShapePattern implementation that matches only if the shape equals a Shape
    684 // proto.
    685 class ShapePatternEqualImpl {
    686  public:
    687   explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape)
    688       : shape_(shape) {}
    689 
    690   bool Match(const ::xla::Shape* shape, MatchOption option) const {
    691     if (!ShapeUtil::Equal(*shape_, *shape)) {
    692       EXPLAIN << "Shape not equal to "
    693               << ShapeUtil::HumanStringWithLayout(*shape_);
    694       return false;
    695     }
    696     return true;
    697   }
    698 
    699   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    700     *os << "equal to " << ShapeUtil::HumanStringWithLayout(*shape_);
    701   }
    702 
    703  private:
    704   const ::xla::Shape* shape_;
    705 };
    706 
    707 // A ShapePattern implementation that matches only if the shape is compatible to
    708 // a Shape proto.
    709 class ShapePatternCompatibleImpl {
    710  public:
    711   explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape)
    712       : shape_(shape) {}
    713 
    714   bool Match(const ::xla::Shape* shape, MatchOption option) const {
    715     if (!ShapeUtil::Compatible(*shape_, *shape)) {
    716       EXPLAIN << "Shape not compatible with "
    717               << ShapeUtil::HumanString(*shape_);
    718       return false;
    719     }
    720     return true;
    721   }
    722 
    723   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    724     *os << "compatible with " << ShapeUtil::HumanString(*shape_);
    725   }
    726 
    727  private:
    728   const ::xla::Shape* shape_;
    729 };
    730 
    731 // A ShapePattern implementation that matches only if the shape has a given
    732 // element type.
    733 class ShapePatternElementTypeImpl {
    734  public:
    735   explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type)
    736       : element_type_(element_type) {}
    737 
    738   bool Match(const ::xla::Shape* shape, MatchOption option) const {
    739     if (shape->element_type() != element_type_) {
    740       EXPLAIN << "Shape does not have element type "
    741               << PrimitiveType_Name(element_type_);
    742       return false;
    743     }
    744     return true;
    745   }
    746 
    747   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    748     *os << "with element type " << PrimitiveType_Name(element_type_);
    749   }
    750 
    751  private:
    752   PrimitiveType element_type_;
    753 };
    754 
    755 // A ShapePattern implementation that matches only if the shape is scalar.
    756 class ShapePatternIsScalarImpl {
    757  public:
    758   explicit constexpr ShapePatternIsScalarImpl() {}
    759 
    760   bool Match(const ::xla::Shape* shape, MatchOption option) const {
    761     if (!ShapeUtil::IsScalar(*shape)) {
    762       EXPLAIN << "Shape is not a scalar";
    763       return false;
    764     }
    765     return true;
    766   }
    767 
    768   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    769     *os << "that represents a scalar";
    770   }
    771 };
    772 
    773 // A ShapePattern implementation that matches only if the shape is an array
    774 class ShapePatternIsArrayImpl {
    775  public:
    776   explicit constexpr ShapePatternIsArrayImpl() {}
    777 
    778   bool Match(const ::xla::Shape* shape, MatchOption option) const {
    779     if (!shape->IsArray()) {
    780       EXPLAIN << "Shape is not an array";
    781       return false;
    782     }
    783     return true;
    784   }
    785 
    786   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    787     *os << "that represents an array";
    788   }
    789 };
    790 
    791 // A ShapePattern implementation that matches only if the shape is a tuple.
    792 class ShapePatternIsTupleImpl {
    793  public:
    794   explicit constexpr ShapePatternIsTupleImpl() {}
    795 
    796   bool Match(const ::xla::Shape* shape, MatchOption option) const {
    797     if (!shape->IsTuple()) {
    798       EXPLAIN << "Shape is not a tuple";
    799       return false;
    800     }
    801     return true;
    802   }
    803 
    804   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    805     *os << "that represents a tuple";
    806   }
    807 };
    808 
    809 // A ShapePattern implementation that matches only if the shape is an effective
    810 // scalar.
    811 class ShapePatternEffectiveScalarImpl {
    812  public:
    813   explicit constexpr ShapePatternEffectiveScalarImpl() {}
    814 
    815   bool Match(const ::xla::Shape* shape, MatchOption option) const {
    816     if (!ShapeUtil::IsEffectiveScalar(*shape)) {
    817       EXPLAIN << "Shape is not an effective scalar";
    818       return false;
    819     }
    820     return true;
    821   }
    822 
    823   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    824     *os << "that is an effective scalar";
    825   }
    826 };
    827 
    828 // A ShapePattern implementation that matches only if the shape has a given
    829 // rank.
    830 class ShapePatternRankImpl {
    831  public:
    832   explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {}
    833 
    834   bool Match(const ::xla::Shape* shape, MatchOption option) const {
    835     if (shape->rank() != rank_) {
    836       if (rank_ == 0) {
    837         EXPLAIN << "Shape is not a scalar";
    838       } else {
    839         EXPLAIN << "Shape does not have rank " << rank_;
    840       }
    841       return false;
    842     }
    843     return true;
    844   }
    845 
    846   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    847     if (rank_ == 0) {
    848       *os << "that is a scalar";
    849     } else {
    850       *os << "that has " << rank_ << " dimension" << (rank_ != 1 ? "s" : "");
    851     }
    852   }
    853 
    854  private:
    855   int64 rank_;
    856 };
    857 
    858 // A ShapePattern implementation that matches only if the shape has a layout
    859 // that matches a given pattern.
    860 template <typename LayoutType, typename LayoutImpl>
    861 class ShapePatternLayoutImpl {
    862  public:
    863   explicit constexpr ShapePatternLayoutImpl(
    864       const LayoutPattern<LayoutType, LayoutImpl>& layout)
    865       : layout_(layout) {}
    866 
    867   bool Match(const ::xla::Shape* shape, MatchOption option) const {
    868     return LayoutUtil::HasLayout(*shape) &&
    869            layout_.Match(&shape->layout(), option);
    870   }
    871 
    872   bool Match(Shape* shape, MatchOption option) const {
    873     if (!LayoutUtil::HasLayout(*shape)) {
    874       EXPLAIN << "Shape does not have a layout";
    875       return false;
    876     }
    877     if (!layout_.Match(shape->mutable_layout(), option)) {
    878       EXPLAIN << "\nin layout";
    879       return false;
    880     }
    881     return true;
    882   }
    883 
    884   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    885     *os << "with";
    886     Indent(os, indent + kIndentInc);
    887     layout_.DescribeTo(os, indent + kIndentInc);
    888   }
    889 
    890  private:
    891   LayoutPattern<LayoutType, LayoutImpl> layout_;
    892 };
    893 
    894 // A ShapePattern implementation that matches only if the shape has a subshape
    895 // that matches a given pattern.
    896 template <typename SubshapeType, typename SubshapeImpl>
    897 class ShapePatternSubshapeImpl {
    898  public:
    899   explicit ShapePatternSubshapeImpl(
    900       ShapeIndexView index,
    901       const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
    902       : index_(index), subshape_(subshape) {}
    903 
    904   bool Match(const ::xla::Shape* shape, MatchOption option) const {
    905     return MatchImpl(shape, option);
    906   }
    907 
    908   bool Match(::xla::Shape* shape, MatchOption option) const {
    909     return MatchImpl(shape, option);
    910   }
    911 
    912   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    913     *os << "with subshape at index " << index_.ToString() << " which is";
    914     Indent(os, indent + kIndentInc);
    915     subshape_.DescribeTo(os, indent + kIndentInc);
    916   }
    917 
    918  private:
    919   Shape* GetSubshape(Shape* shape) const {
    920     return ShapeUtil::GetMutableSubshape(shape, index_);
    921   }
    922   const Shape* GetSubshape(const Shape* shape) const {
    923     return &ShapeUtil::GetSubshape(*shape, index_);
    924   }
    925 
    926   template <typename ShapeType>
    927   bool MatchImpl(ShapeType* shape, MatchOption option) const {
    928     if (!ShapeUtil::IndexIsValid(*shape, index_)) {
    929       EXPLAIN << "No subshape at " << index_.ToString();
    930       return false;
    931     }
    932     if (!subshape_.Match(GetSubshape(shape), option)) {
    933       EXPLAIN << "\nin subshape at " << index_.ToString();
    934       return false;
    935     }
    936     return true;
    937   }
    938 
    939   ShapeIndexView index_;
    940   ShapePattern<SubshapeType, SubshapeImpl> subshape_;
    941 };
    942 
    943 // A pattern that matches Shapes.
    944 template <typename ShapeType, typename Impl>
    945 class ShapePattern {
    946  private:
    947   template <typename NewImpl>
    948   auto AppendImpl(NewImpl new_impl) const
    949       -> ShapePattern<ShapeType, decltype(AllOf<Shape>(std::declval<Impl>(),
    950                                                        std::move(new_impl)))> {
    951     auto new_all_of = AllOf<Shape>(impl_, std::move(new_impl));
    952     return ShapePattern<ShapeType, decltype(new_all_of)>(std::move(new_all_of),
    953                                                          matched_shape_);
    954   }
    955 
    956  public:
    957   explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape)
    958       : impl_(impl), matched_shape_(matched_shape) {}
    959 
    960   // Returns true and captures the shape iff it matches the pattern.
    961   bool Match(const ::xla::Shape* shape, MatchOption option) const {
    962     if (impl_.Match(shape, option)) {
    963       if (option.capture && matched_shape_) {
    964         *matched_shape_ = shape;
    965       }
    966       return true;
    967     }
    968     if (shape) {
    969       EXPLAIN << "\nin "
    970               << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
    971                                       : ShapeUtil::HumanString(*shape));
    972     }
    973     return false;
    974   }
    975 
    976   // Returns true and captures the shape iff it matches the pattern.
    977   bool Match(::xla::Shape* shape, MatchOption option) const {
    978     if (impl_.Match(shape, option)) {
    979       if (option.capture && matched_shape_) {
    980         *matched_shape_ = shape;
    981       }
    982       return true;
    983     }
    984     EXPLAIN << "\nin "
    985             << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
    986                                     : ShapeUtil::HumanString(*shape));
    987     return false;
    988   }
    989 
    990   void DescribeTo(std::ostream* os, int64 indent = 0) const {
    991     return impl_.DescribeTo(os, indent);
    992   }
    993 
    994   // Modifies the pattern to match only if the shape equals the given proto.
    995   // The layout must outlive the returned pattern.
    996   constexpr auto EqualTo(const ::xla::Shape* shape) const
    997       -> decltype(this->AppendImpl(ShapePatternEqualImpl(shape))) {
    998     return AppendImpl(ShapePatternEqualImpl(shape));
    999   }
   1000 
   1001   // Modifies the pattern to match only if the shape is compatible to the given
   1002   // proto. The layout must outlive the returned pattern.
   1003   constexpr auto CompatibleTo(const ::xla::Shape* shape) const
   1004       -> decltype(this->AppendImpl(ShapePatternCompatibleImpl(shape))) {
   1005     return AppendImpl(ShapePatternCompatibleImpl(shape));
   1006   }
   1007 
   1008   // Modifies the pattern to match only if the shape has the given element type.
   1009   constexpr auto WithElementType(PrimitiveType element_type) const
   1010       -> decltype(this->AppendImpl(ShapePatternElementTypeImpl(element_type))) {
   1011     return AppendImpl(ShapePatternElementTypeImpl(element_type));
   1012   }
   1013 
   1014   // Modifies the pattern to match only if the shape is scalar.
   1015   constexpr auto IsScalar() const
   1016       -> decltype(this->AppendImpl(ShapePatternIsScalarImpl())) {
   1017     return AppendImpl(ShapePatternIsScalarImpl());
   1018   }
   1019 
   1020   // Modifies the pattern to match only if the shape is an array.
   1021   constexpr auto IsArray() const
   1022       -> decltype(this->AppendImpl(ShapePatternIsArrayImpl())) {
   1023     return AppendImpl(ShapePatternIsArrayImpl());
   1024   }
   1025 
   1026   // Modifies the pattern to match only if the shape is a tuple.
   1027   constexpr auto IsTuple() const
   1028       -> decltype(this->AppendImpl(ShapePatternIsTupleImpl())) {
   1029     return AppendImpl(ShapePatternIsTupleImpl());
   1030   }
   1031 
   1032   constexpr auto IsEffectiveScalar() const
   1033       -> decltype(this->AppendImpl(ShapePatternEffectiveScalarImpl())) {
   1034     return AppendImpl(ShapePatternEffectiveScalarImpl());
   1035   }
   1036 
   1037   // Modifies the pattern to match only if the shape has the given rank.
   1038   constexpr auto WithRank(int64 rank) const
   1039       -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) {
   1040     return AppendImpl(ShapePatternRankImpl(rank));
   1041   }
   1042 
   1043   // Modifies the pattern to match only if the shape has a layout that matches
   1044   // the given pattern.
   1045   template <typename LayoutType, typename LayoutImpl>
   1046   auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const
   1047       -> decltype(this->AppendImpl(
   1048           ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout))) {
   1049     return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout));
   1050   }
   1051 
   1052   constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const
   1053       -> decltype(this->WithLayout(Layout().EqualTo(layout))) {
   1054     return WithLayout(Layout().EqualTo(layout));
   1055   }
   1056 
   1057   constexpr auto IsDenseArray() const
   1058       -> decltype(this->WithLayout(Layout().WithDenseFormat())) {
   1059     return WithLayout(Layout().WithDenseFormat());
   1060   }
   1061 
   1062   constexpr auto IsSparseArray() const
   1063       -> decltype(this->WithLayout(Layout().WithSparseFormat())) {
   1064     return WithLayout(Layout().WithSparseFormat());
   1065   }
   1066 
   1067   // Modifies the pattern to match only if the shape has a subshape that matches
   1068   // the given pattern.
   1069   template <typename SubshapeType, typename SubshapeImpl>
   1070   auto WithSubshape(ShapeIndexView index,
   1071                     const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
   1072       const -> decltype(this->AppendImpl(
   1073           ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index,
   1074                                                                subshape))) {
   1075     return AppendImpl(
   1076         ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape));
   1077   }
   1078 
   1079   ShapePattern<ShapeType,
   1080                AllOfPattern<Shape, Impl,
   1081                             ShapePatternSubshapeImpl<
   1082                                 const ::xla::Shape,
   1083                                 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
   1084                                              ShapePatternEqualImpl>>>>
   1085   WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const {
   1086     return WithSubshape(index,
   1087                         ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
   1088                             ShapePatternBaseImpl(), nullptr)
   1089                             .EqualTo(shape));
   1090   }
   1091 
   1092   ShapePattern<ShapeType,
   1093                AllOfPattern<Shape, Impl,
   1094                             ShapePatternSubshapeImpl<
   1095                                 const ::xla::Shape,
   1096                                 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
   1097                                              ShapePatternCompatibleImpl>>>>
   1098   WithSubshapeCompatibleTo(ShapeIndexView index,
   1099                            const ::xla::Shape* shape) const {
   1100     return WithSubshape(index,
   1101                         ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
   1102                             ShapePatternBaseImpl(), nullptr)
   1103                             .CompatibleTo(shape));
   1104   }
   1105 
   1106  private:
   1107   Impl impl_;
   1108   ShapeType** matched_shape_;
   1109 };
   1110 
   1111 }  // namespace detail
   1112 
   1113 // Creates a shape pattern that will capture the matched layout in the argument.
   1114 inline constexpr detail::ShapePattern<const ::xla::Shape,
   1115                                       detail::ShapePatternBaseImpl>
   1116 Shape(const ::xla::Shape** matched_shape = nullptr) {
   1117   return detail::ShapePattern<const ::xla::Shape, detail::ShapePatternBaseImpl>(
   1118       detail::ShapePatternBaseImpl(), matched_shape);
   1119 }
   1120 
   1121 // Creates a shape pattern that will capture the matched layout in the argument.
   1122 inline constexpr detail::ShapePattern<::xla::Shape,
   1123                                       detail::ShapePatternBaseImpl>
   1124 Shape(::xla::Shape** matched_shape) {
   1125   return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>(
   1126       detail::ShapePatternBaseImpl(), matched_shape);
   1127 }
   1128 
   1129 namespace detail {
   1130 
   1131 // Overloads to get a const or non-const operand out of an instruction.
   1132 inline HloInstruction* HloOperand(HloInstruction* instr, int64 idx) {
   1133   return instr->mutable_operand(idx);
   1134 }
   1135 inline const HloInstruction* HloOperand(const HloInstruction* instr,
   1136                                         int64 idx) {
   1137   return instr->operand(idx);
   1138 }
   1139 
   1140 // Pretty-printer for HloInstruction.  Sort of like ToShortString, but with
   1141 // fewer %s and more shapes.
   1142 inline string InstToString(const HloInstruction* inst) {
   1143   return inst->ToString(
   1144       HloPrintOptions().set_print_metadata(false).set_print_percent(false));
   1145 }
   1146 
   1147 template <typename HloInstructionType, typename Impl>
   1148 class HloInstructionPattern;
   1149 
   1150 // The base HloInstructionPattern implementation. Matches only if the
   1151 // instruction is not nullptr.
   1152 class HloInstructionPatternBaseImpl {
   1153  public:
   1154   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
   1155     if (inst == nullptr) {
   1156       EXPLAIN << "HloInstruction* is null";
   1157       return false;
   1158     }
   1159     return true;
   1160   }
   1161 
   1162   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1163     *os << "an HloInstruction";
   1164   }
   1165 
   1166   static constexpr bool kIsTrivialMatcher = true;
   1167 };
   1168 
   1169 // An HloInstructionPattern implementation that matches only if the instruction
   1170 // has a given name.
   1171 class HloInstructionPatternNameImpl {
   1172  public:
   1173   explicit HloInstructionPatternNameImpl(absl::string_view name)
   1174       : name_(name) {}
   1175 
   1176   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
   1177     if (inst->name() != name_) {
   1178       EXPLAIN << "HloInstruction not named \"" << name_ << "\"";
   1179       return false;
   1180     }
   1181     return true;
   1182   }
   1183 
   1184   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1185     *os << "named \"" << name_ << "\"";
   1186   }
   1187 
   1188  private:
   1189   absl::string_view name_;
   1190 };
   1191 
   1192 // An HloInstructionPattern implementation that matches only if the instruction
   1193 // equals a particular pointer.
   1194 class HloInstructionIsImpl {
   1195  public:
   1196   explicit HloInstructionIsImpl(const HloInstruction* inst) : inst_(inst) {}
   1197 
   1198   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
   1199     if (inst != inst_) {
   1200       EXPLAIN << "HloInstruction " << inst << " is not " << inst_ << " ("
   1201               << InstToString(inst_) << ")";
   1202       return false;
   1203     }
   1204     return true;
   1205   }
   1206 
   1207   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1208     *os << "which is " << inst_ << " (" << InstToString(inst_) << ")";
   1209   }
   1210 
   1211  private:
   1212   const HloInstruction* inst_;
   1213 };
   1214 
   1215 // An HloInstructionPattern implementation that matches only if the instruction
   1216 // has a given opcode.
   1217 class HloInstructionPatternOpcodeImpl {
   1218  public:
   1219   explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode,
   1220                                                      bool invert)
   1221       : opcode_(opcode), invert_(invert) {}
   1222 
   1223   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
   1224     if (invert_ && inst->opcode() == opcode_) {
   1225       EXPLAIN << "HloInstruction has opcode " << HloOpcodeString(opcode_)
   1226               << ", expected anything else";
   1227       return false;
   1228     }
   1229     if (!invert_ && inst->opcode() != opcode_) {
   1230       EXPLAIN << "HloInstruction doesn't have opcode "
   1231               << HloOpcodeString(opcode_);
   1232       return false;
   1233     }
   1234     return true;
   1235   }
   1236 
   1237   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1238     if (!invert_) {
   1239       *os << "with opcode " << HloOpcodeString(opcode_);
   1240     } else {
   1241       *os << "with any opcode other than " << HloOpcodeString(opcode_);
   1242     }
   1243   }
   1244 
   1245  private:
   1246   HloOpcode opcode_;
   1247   bool invert_;
   1248 };
   1249 
   1250 // An HloInstructionPattern implementation that matches only if the instruction
   1251 // has the given number of operands.
   1252 class HloInstructionPatternNumOperandsImpl {
   1253  public:
   1254   explicit constexpr HloInstructionPatternNumOperandsImpl(int64 num_operands)
   1255       : num_operands_(num_operands) {}
   1256 
   1257   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
   1258     if (inst->operand_count() != num_operands_) {
   1259       EXPLAIN << "HloInstruction doesn't have " << num_operands_ << " operands";
   1260       return false;
   1261     }
   1262     return true;
   1263   }
   1264 
   1265   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1266     *os << "with " << num_operands_ << " operand"
   1267         << (num_operands_ != 1 ? "s" : "");
   1268   }
   1269 
   1270  private:
   1271   int64 num_operands_;
   1272 };
   1273 
   1274 // An HloInstructionPattern implementation that matches only if the instruction
   1275 // has a shape that matches a given pattern.
   1276 template <typename ShapeType, typename ShapeImpl>
   1277 class HloInstructionPatternShapeImpl {
   1278  public:
   1279   explicit constexpr HloInstructionPatternShapeImpl(
   1280       const ShapePattern<ShapeType, ShapeImpl>& shape)
   1281       : shape_(shape) {}
   1282 
   1283   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
   1284     if (!shape_.Match(&inst->shape(), option)) {
   1285       EXPLAIN << "\nin output shape";
   1286       return false;
   1287     }
   1288     return true;
   1289   }
   1290 
   1291   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
   1292     if (!shape_.Match(inst->mutable_shape(), option)) {
   1293       EXPLAIN << "\nin output shape";
   1294       return false;
   1295     }
   1296     return true;
   1297   }
   1298 
   1299   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1300     *os << "outputting";
   1301     Indent(os, indent + kIndentInc);
   1302     shape_.DescribeTo(os, indent + kIndentInc);
   1303   }
   1304 
   1305  private:
   1306   ShapePattern<ShapeType, ShapeImpl> shape_;
   1307 };
   1308 
   1309 // An HloInstructionPattern implementation that matches only if the instruction
   1310 // has an operand that matches a given pattern.
   1311 template <typename OperandType, typename OperandImpl>
   1312 class HloInstructionPatternOperandImpl {
   1313  public:
   1314   explicit constexpr HloInstructionPatternOperandImpl(
   1315       int64 operand_index,
   1316       const HloInstructionPattern<OperandType, OperandImpl>& operand)
   1317       : operand_index_(operand_index), operand_(operand) {}
   1318 
   1319   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
   1320     return MatchImpl(inst, option);
   1321   }
   1322 
   1323   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
   1324     return MatchImpl(inst, option);
   1325   }
   1326 
   1327   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1328     *os << "with operand " << operand_index_ << " which is:";
   1329     Indent(os, indent + kIndentInc);
   1330     operand_.DescribeTo(os, indent + kIndentInc);
   1331   }
   1332 
   1333  private:
   1334   template <typename HloInstructionType>
   1335   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
   1336     if (operand_index_ >= inst->operand_count()) {
   1337       EXPLAIN << "desired operand index " << operand_index_
   1338               << " is out of bounds";
   1339       return false;
   1340     }
   1341     if (!operand_.Match(HloOperand(inst, operand_index_), option)) {
   1342       EXPLAIN << "\nin operand " << operand_index_;
   1343       return false;
   1344     }
   1345     return true;
   1346   }
   1347 
   1348   int64 operand_index_;
   1349   HloInstructionPattern<OperandType, OperandImpl> operand_;
   1350 };
   1351 
   1352 // Matches a binary instruction whose operands come in any order.
   1353 template <typename OperandType1, typename OperandImpl1, typename OperandType2,
   1354           typename OperandImpl2>
   1355 class HloInstructionPatternBinaryOperandsAnyOrderImpl {
   1356  public:
   1357   explicit constexpr HloInstructionPatternBinaryOperandsAnyOrderImpl(
   1358       const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
   1359       const HloInstructionPattern<OperandType2, OperandImpl2>& op2)
   1360       : op1_(op1), op2_(op2) {}
   1361 
   1362   bool Match(HloInstruction* inst, MatchOption option) const {
   1363     return MatchImpl(inst, option);
   1364   }
   1365 
   1366   bool Match(const HloInstruction* inst, MatchOption option) const {
   1367     return MatchImpl(inst, option);
   1368   }
   1369 
   1370   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1371     *os << "with two operands in either order:";
   1372     Indent(os, indent);
   1373     *os << " - ";
   1374     op1_.DescribeTo(os, indent + 3);
   1375     Indent(os, indent);
   1376     *os << " - ";
   1377     op2_.DescribeTo(os, indent + 3);
   1378   }
   1379 
   1380  private:
   1381   HloInstruction* operand(HloInstruction* inst, int64 idx) const {
   1382     return inst->mutable_operand(idx);
   1383   }
   1384   const HloInstruction* operand(const HloInstruction* inst, int64 idx) const {
   1385     return inst->operand(idx);
   1386   }
   1387 
   1388   template <typename HloInstructionType>
   1389   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
   1390     // We could implement this using AnyOf and AllOf matchers, but the templates
   1391     // get pretty difficult to debug, since any compile error herein becomes
   1392     // not-an-error via SFINAE.  Also this way lets us give better messages on
   1393     // failure.
   1394     if (inst->operand_count() != 2) {
   1395       EXPLAIN << "HloInstruction did not have two operands";
   1396       return false;
   1397     }
   1398 
   1399     // If we're not generating explanations, this is pretty simple.
   1400     if (!option.explain_os) {
   1401       auto try_match = [&](int64 idx1, int64 idx2) {
   1402         MatchOption new_option = option;
   1403         new_option.capture = false;
   1404         if (op1_.Match(operand(inst, idx1), new_option) &&
   1405             op2_.Match(operand(inst, idx2), new_option)) {
   1406           if (option.capture) {
   1407             bool matched = op1_.Match(operand(inst, idx1), option) &&
   1408                            op2_.Match(operand(inst, idx2), option);
   1409             DCHECK(matched);
   1410           }
   1411           return true;
   1412         }
   1413         return false;
   1414       };
   1415       return try_match(0, 1) || try_match(1, 0);
   1416     }
   1417 
   1418     // If we are generating explanations, we have some work to do in order to
   1419     // generate a helpful error.
   1420     //
   1421     // First, try all four operand/matcher combinations, recording the
   1422     // failure explanations separately from option.explain_os. matches[i][j]
   1423     // tells us if matcher_i matches operand j.
   1424     bool matches[/*matcher*/ 2][/*operand*/ 2];
   1425     std::stringstream explanations[/*matcher*/ 2][/*operand*/ 2];
   1426     for (int i = 0; i < 2; ++i) {
   1427       for (int j = 0; j < 2; ++j) {
   1428         MatchOption new_option = option;
   1429         new_option.capture = false;
   1430         new_option.explain_os = &explanations[i][j];
   1431         matches[i][j] = i == 0 ? op1_.Match(operand(inst, j), new_option)
   1432                                : op2_.Match(operand(inst, j), new_option);
   1433       }
   1434     }
   1435 
   1436     // Check if the match succeeded.
   1437     for (int i = 0; i < 2; ++i) {
   1438       if (matches[0][i] && matches[1][(i + 1) % 2]) {
   1439         // Rerun the matches with capture enabled if necessary.
   1440         if (option.capture) {
   1441           auto* operand1 = operand(inst, i);
   1442           auto* operand2 = operand(inst, (i + 1) % 2);
   1443           bool matched =
   1444               op1_.Match(operand1, option) && op2_.Match(operand2, option);
   1445           DCHECK(matched);
   1446         }
   1447         return true;
   1448       }
   1449     }
   1450 
   1451     auto describe_matcher = [&](int matcher_idx) {
   1452       EXPLAIN << "\n - ";
   1453       if (matcher_idx == 0) {
   1454         op1_.DescribeTo(option.explain_os, /*indent=*/3);
   1455       } else {
   1456         CHECK_EQ(matcher_idx, 1);
   1457         op2_.DescribeTo(option.explain_os, /*indent=*/3);
   1458       }
   1459       for (int i = 0; i < 2; ++i) {
   1460         if (matches[matcher_idx][/*operand*/ i]) {
   1461           continue;
   1462         }
   1463         EXPLAIN << "\ndoes not match " << (i == 0 ? "LHS" : "RHS") << ":\n";
   1464         EXPLAIN << " - ";
   1465         EXPLAIN << absl::StrReplaceAll(
   1466             explanations[matcher_idx][/*operand*/ i].str(), {{"\n", "\n   "}});
   1467       }
   1468     };
   1469 
   1470     // If we failed to match, one of the following is true:
   1471     //  1. op1 (op2) matches neither LHS nor RHS, or
   1472     //  2. op1 and op2 both match LHS (RHS), but neither matches RHS (LHS).
   1473     // We print different explanations depending on which case we're in.
   1474 
   1475     // Case 1.
   1476     bool wrote_explanation = false;
   1477     for (int i = 0; !wrote_explanation && i < 2; ++i) {
   1478       if (!matches[i][0] && !matches[i][1]) {
   1479         EXPLAIN << "HloInstruction's operands (ignoring order) did not match "
   1480                 << (i == 0 ? "first" : "second") << " matcher.  Specifically,";
   1481         describe_matcher(i);
   1482         wrote_explanation = true;
   1483       }
   1484     }
   1485 
   1486     // Case 2.
   1487     for (int i = 0; !wrote_explanation && i < 2; ++i) {
   1488       if (matches[/*matcher*/ 0][/*operand*/ i] &&
   1489           matches[/*matcher*/ 1][/*operand*/ i]) {
   1490         CHECK(!matches[0][(i + 1) % 2]);
   1491         CHECK(!matches[1][(i + 1) % 2]);
   1492         CHECK(!wrote_explanation);
   1493         EXPLAIN << "HloInstruction's " << (i == 1 ? "LHS" : "RHS")
   1494                 << " operand did not match either of the two matchers.  "
   1495                    "Specifically,";
   1496         describe_matcher(0);
   1497         EXPLAIN << "\nand";
   1498         describe_matcher(1);
   1499         wrote_explanation = true;
   1500       }
   1501     }
   1502 
   1503     CHECK(wrote_explanation);
   1504     return false;
   1505   }
   1506 
   1507   HloInstructionPattern<OperandType1, OperandImpl1> op1_;
   1508   HloInstructionPattern<OperandType2, OperandImpl2> op2_;
   1509 };
   1510 
   1511 // An HloInstructionPattern implementation that matches only if the instruction
   1512 // is a fusion node with a particular kind.
   1513 class HloInstructionPatternFusionKindImpl {
   1514  public:
   1515   explicit constexpr HloInstructionPatternFusionKindImpl(
   1516       ::xla::HloInstruction::FusionKind kind)
   1517       : kind_(kind) {}
   1518 
   1519   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
   1520     return MatchImpl(inst, option);
   1521   }
   1522 
   1523   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
   1524     return MatchImpl(inst, option);
   1525   }
   1526 
   1527   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1528     *os << "with fusion kind " << ToString(kind_);
   1529   }
   1530 
   1531  private:
   1532   template <typename HloInstructionType>
   1533   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
   1534     if (inst->opcode() != HloOpcode::kFusion) {
   1535       EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_)
   1536               << "; it's not a fusion";
   1537       return false;
   1538     }
   1539     if (inst->fusion_kind() != kind_) {
   1540       EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_);
   1541       return false;
   1542     }
   1543     return true;
   1544   }
   1545 
   1546   ::xla::HloInstruction::FusionKind kind_;
   1547 };
   1548 
   1549 // An HloInstructionPattern implementation that matches only if the instruction
   1550 // is a kGetTupleElement with a particular tuple index.
   1551 class HloInstructionPatternTupleIndexImpl {
   1552  public:
   1553   explicit constexpr HloInstructionPatternTupleIndexImpl(int64 tuple_index)
   1554       : tuple_index_(tuple_index) {}
   1555 
   1556   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
   1557     return MatchImpl(inst, option);
   1558   }
   1559 
   1560   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
   1561     return MatchImpl(inst, option);
   1562   }
   1563 
   1564   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1565     *os << "which is a GTE with index " << tuple_index_;
   1566   }
   1567 
   1568  private:
   1569   template <typename HloInstructionType>
   1570   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
   1571     if (inst->opcode() != HloOpcode::kGetTupleElement) {
   1572       EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_
   1573               << "; it's not a GTE at all";
   1574       return false;
   1575     }
   1576     if (inst->tuple_index() != tuple_index_) {
   1577       EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_;
   1578       return false;
   1579     }
   1580     return true;
   1581   }
   1582 
   1583   int64 tuple_index_;
   1584 };
   1585 
   1586 class HloInstructionPatternParameterNumImpl {
   1587  public:
   1588   explicit constexpr HloInstructionPatternParameterNumImpl(int64 parameter_num)
   1589       : parameter_num_(parameter_num) {}
   1590 
   1591   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
   1592     return MatchImpl(inst, option);
   1593   }
   1594 
   1595   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
   1596     return MatchImpl(inst, option);
   1597   }
   1598 
   1599   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1600     *os << "which is parameter " << parameter_num_;
   1601   }
   1602 
   1603  private:
   1604   template <typename HloInstructionType>
   1605   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
   1606     if (inst->opcode() != HloOpcode::kParameter ||
   1607         inst->parameter_number() != parameter_num_) {
   1608       EXPLAIN << "HloInstruction is not parameter " << parameter_num_;
   1609       return false;
   1610     }
   1611     return true;
   1612   }
   1613 
   1614   int64 parameter_num_;
   1615 };
   1616 
   1617 // Superclass that contains common code used by Op::WithOneUse() and
   1618 // Op::WithOneUser().
   1619 class HloInstructionPatternOneUseOrUserImpl {
   1620  protected:
   1621   bool MatchOneUser(const HloInstruction* inst, MatchOption option) const {
   1622     if (inst->user_count() != 1) {
   1623       EXPLAIN << "HloInstruction has " << inst->user_count()
   1624               << " users, but expected exactly one.";
   1625       if (inst->user_count() > 1) {
   1626         EXPLAIN << "\nAll users:";
   1627         for (const HloInstruction* user : inst->users()) {
   1628           EXPLAIN << "\n - " << InstToString(user);
   1629         }
   1630       }
   1631       return false;
   1632     }
   1633     return true;
   1634   }
   1635 };
   1636 
   1637 class HloInstructionPatternOneUseImpl
   1638     : public HloInstructionPatternOneUseOrUserImpl {
   1639  public:
   1640   bool Match(const HloInstruction* inst, MatchOption option) const {
   1641     if (!MatchOneUser(inst, option)) {
   1642       return false;
   1643     }
   1644 
   1645     int64 use_count = absl::c_count_if(
   1646         inst->users()[0]->operands(),
   1647         [&](const HloInstruction* operand) { return operand == inst; });
   1648     if (use_count != 1) {
   1649       EXPLAIN << "HloInstruction is used " << use_count
   1650               << " times by its user, but is expected to be used just once: "
   1651               << InstToString(inst->users()[0]);
   1652       return false;
   1653     }
   1654     return true;
   1655   }
   1656 
   1657   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1658     *os << "which has exactly one use";
   1659   }
   1660 };
   1661 
   1662 class HloInstructionPatternOneUserImpl
   1663     : public HloInstructionPatternOneUseOrUserImpl {
   1664  public:
   1665   bool Match(const HloInstruction* inst, MatchOption option) const {
   1666     return MatchOneUser(inst, option);
   1667   }
   1668 
   1669   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1670     *os << "which has exactly one user (but possibly is used multiple times by "
   1671            "that instruction)";
   1672   }
   1673 };
   1674 
   1675 class HloInstructionPatternComparisonDirectionImpl {
   1676  public:
   1677   explicit constexpr HloInstructionPatternComparisonDirectionImpl(
   1678       ComparisonDirection direction)
   1679       : direction_(direction) {}
   1680 
   1681   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
   1682     return MatchImpl(inst, option);
   1683   }
   1684 
   1685   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
   1686     return MatchImpl(inst, option);
   1687   }
   1688 
   1689   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1690     *os << "which has comparison direction "
   1691         << ComparisonDirectionToString(direction_);
   1692   }
   1693 
   1694  private:
   1695   template <typename HloInstructionType>
   1696   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
   1697     if (inst->opcode() != HloOpcode::kCompare ||
   1698         inst->comparison_direction() != direction_) {
   1699       EXPLAIN << "HloInstruction is not comparison "
   1700               << ComparisonDirectionToString(direction_);
   1701       return false;
   1702     }
   1703     return true;
   1704   }
   1705 
   1706   ComparisonDirection direction_;
   1707 };
   1708 
   1709 // Matches a constant scalar or effective scalar, optionally with a given value.
   1710 template <typename ScalarTy>
   1711 class HloConstantScalarImpl {
   1712  public:
   1713   explicit constexpr HloConstantScalarImpl(bool match_effective_scalar)
   1714       : val_(absl::nullopt), match_effective_scalar_(match_effective_scalar) {}
   1715 
   1716   constexpr HloConstantScalarImpl(ScalarTy val, bool match_effective_scalar)
   1717       : val_(val), match_effective_scalar_(match_effective_scalar) {}
   1718 
   1719   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
   1720     return MatchImpl(inst, option);
   1721   }
   1722 
   1723   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
   1724     return MatchImpl(inst, option);
   1725   }
   1726 
   1727   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   1728     *os << "which is a constant "
   1729         << (match_effective_scalar_ ? "effective " : "") << "scalar";
   1730     if (val_.has_value()) {
   1731       *os << " with value " << *val_;
   1732     }
   1733   }
   1734 
   1735  private:
   1736   template <typename InstTy>
   1737   bool MatchImpl(InstTy* inst, MatchOption option) const {
   1738     const auto* const_inst = DynCast<HloConstantInstruction>(inst);
   1739     if (!const_inst) {
   1740       EXPLAIN << "HloInstruction is not a constant";
   1741       return false;
   1742     }
   1743     if (match_effective_scalar_ &&
   1744         !ShapeUtil::IsEffectiveScalar(inst->shape())) {
   1745       EXPLAIN << "HloInstruction is not an effective scalar";
   1746       return false;
   1747     }
   1748     if (!match_effective_scalar_ && !ShapeUtil::IsScalar(inst->shape())) {
   1749       EXPLAIN << "HloInstruction is not a scalar";
   1750       return false;
   1751     }
   1752     if (!val_.has_value()) {
   1753       return true;
   1754     }
   1755 
   1756     // Check that literal == static_cast<LitearlTy>(val) and
   1757     // val == static_cast<ValTy>(literal).  This is sufficient to ensure that
   1758     // the two constant scalars are actually "equal".
   1759     auto val_literal = LiteralUtil::CreateR0(*val_);
   1760     auto literal_r0_or = const_inst->literal().Reshape({});
   1761     auto val_as_literal_ty_or =
   1762         val_literal.Convert(const_inst->shape().element_type());
   1763     if (!literal_r0_or.ok() || !val_as_literal_ty_or.ok()) {
   1764       EXPLAIN << "could not construct relevant Literals (how did this happen?)";
   1765       return false;
   1766     }
   1767     auto literal_r0 = std::move(literal_r0_or).ValueOrDie();
   1768     auto val_as_literal_ty = std::move(val_as_literal_ty_or).ValueOrDie();
   1769     auto literal_r0_as_val_ty_or =
   1770         literal_r0.Convert(val_literal.shape().element_type());
   1771     bool rv = literal_r0_as_val_ty_or.ok() &&  //
   1772               literal_r0_as_val_ty_or.ValueOrDie() == val_literal &&
   1773               literal_r0 == val_as_literal_ty;
   1774     if (!rv) {
   1775       EXPLAIN << "HloInstruction's constant value "
   1776               << literal_r0.ToStringWithoutShape()
   1777               << " did not match expected value " << *val_;
   1778     }
   1779     return rv;
   1780   }
   1781 
   1782   absl::optional<ScalarTy> val_;
   1783   bool match_effective_scalar_;
   1784 };
   1785 
   1786 // A pattern that matches HloInstructions.
   1787 template <typename HloInstructionType, typename Impl>
   1788 class HloInstructionPattern {
   1789  private:
   1790   template <typename NewImpl>
   1791   auto AppendImpl(NewImpl new_impl) const -> HloInstructionPattern<
   1792       HloInstructionType, decltype(AllOf<HloInstruction>(
   1793                               std::declval<Impl>(), std::move(new_impl)))> {
   1794     auto new_allof = AllOf<HloInstruction>(impl_, std::move(new_impl));
   1795     return HloInstructionPattern<HloInstructionType, decltype(new_allof)>(
   1796         std::move(new_allof), matched_inst_);
   1797   }
   1798 
   1799  public:
   1800   explicit constexpr HloInstructionPattern(const Impl& impl,
   1801                                            HloInstructionType** matched_inst)
   1802       : impl_(impl), matched_inst_(matched_inst) {}
   1803 
   1804   // Returns true and captures the instruction iff it matches the pattern.
   1805   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
   1806     if (impl_.Match(inst, option)) {
   1807       if (option.capture && matched_inst_) {
   1808         *matched_inst_ = inst;
   1809       }
   1810       return true;
   1811     }
   1812     if (inst != nullptr) {
   1813       EXPLAIN << "\nin " << InstToString(inst);
   1814     }
   1815     return false;
   1816   }
   1817 
   1818   // Returns true and captures the instruction iff it matches the pattern.
   1819   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
   1820     if (impl_.Match(inst, option)) {
   1821       if (option.capture && matched_inst_) {
   1822         *matched_inst_ = inst;
   1823       }
   1824       return true;
   1825     }
   1826     EXPLAIN << "\nin " << InstToString(inst);
   1827     return false;
   1828   }
   1829 
   1830   // Modifies the pattern to match only if the instruction has the given name.
   1831   auto WithName(absl::string_view name) const
   1832       -> decltype(this->AppendImpl(HloInstructionPatternNameImpl(name))) {
   1833     return AppendImpl(HloInstructionPatternNameImpl(name));
   1834   }
   1835 
   1836   // Modifies the pattern to match only if the instruction has the given opcode.
   1837   auto WithOpcode(HloOpcode opcode) const
   1838       -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode,
   1839                                                                    false))) {
   1840     return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
   1841   }
   1842 
   1843   auto WithNumOperands(int64 num_operands) const -> decltype(
   1844       this->AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands))) {
   1845     return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));
   1846   }
   1847 
   1848   // Modifies the pattern to match only if the instruction does not have the
   1849   // given opcode.
   1850   auto WithoutOpcode(HloOpcode opcode) const
   1851       -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode,
   1852                                                                    true))) {
   1853     return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true));
   1854   }
   1855 
   1856   constexpr auto Is(const HloInstruction* instr) const
   1857       -> decltype(this->AppendImpl(HloInstructionIsImpl(instr))) {
   1858     return AppendImpl(HloInstructionIsImpl(instr));
   1859   }
   1860 
   1861   // Modifies the pattern to match only if the instruction is a constant.
   1862   constexpr auto IsConstant() const
   1863       -> decltype(this->WithOpcode(HloOpcode::kConstant)) {
   1864     return WithOpcode(HloOpcode::kConstant);
   1865   }
   1866 
   1867   constexpr auto IsConstantScalar() const -> decltype(this->AppendImpl(
   1868       HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false))) {
   1869     return AppendImpl(
   1870         HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false));
   1871   }
   1872 
   1873   // This does not check that T has the same type as the instruction, so e.g.
   1874   // IsConstantScalar(1.0) may match a constant of shape int32[].
   1875   template <typename ScalarTy>
   1876   constexpr auto IsConstantScalar(const ScalarTy& val) const
   1877       -> decltype(this->AppendImpl(HloConstantScalarImpl<ScalarTy>(
   1878           val, /*match_effective_scalar=*/false))) {
   1879     return AppendImpl(
   1880         HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/false));
   1881   }
   1882 
   1883   constexpr auto IsConstantEffectiveScalar() const -> decltype(this->AppendImpl(
   1884       HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true))) {
   1885     return AppendImpl(
   1886         HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true));
   1887   }
   1888 
   1889   template <typename ScalarTy>
   1890   constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const
   1891       -> decltype(this->AppendImpl(HloConstantScalarImpl<ScalarTy>(
   1892           val, /*match_effective_scalar=*/true))) {
   1893     return AppendImpl(
   1894         HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/true));
   1895   }
   1896 
   1897   // Modifies the pattern to match only if the instruction is not a constant.
   1898   constexpr auto IsNonConstant() const
   1899       -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) {
   1900     return WithoutOpcode(HloOpcode::kConstant);
   1901   }
   1902 
   1903   // Modifies the pattern to match only if the instruction has a shape that
   1904   // matches the given pattern.
   1905   template <typename ShapeType, typename ShapeImpl>
   1906   constexpr auto WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape)
   1907       const -> decltype(this->AppendImpl(
   1908           HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape))) {
   1909     return AppendImpl(
   1910         HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape));
   1911   }
   1912 
   1913   // Make this a templated function to work around gcc 4.9.4 template infinite
   1914   // recursion bug.
   1915   template <typename Dummy = void>
   1916   constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const
   1917       -> decltype(this->WithShape(Shape().EqualTo(shape))) {
   1918     return WithShape(Shape().EqualTo(shape));
   1919   }
   1920 
   1921   // Make this a templated function to work around gcc 4.9.4 template infinite
   1922   // recursion bug.
   1923   template <typename Dummy = void>
   1924   constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const
   1925       -> decltype(this->WithShape(Shape().CompatibleTo(shape))) {
   1926     return WithShape(Shape().CompatibleTo(shape));
   1927   }
   1928 
   1929   // Modifies the pattern to match only if the instruction has an operand that
   1930   // matches the given pattern.
   1931   template <typename OperandType, typename OperandImpl>
   1932   constexpr auto WithOperand(
   1933       int64 operand_index,
   1934       const HloInstructionPattern<OperandType, OperandImpl>& operand) const
   1935       -> decltype(this->AppendImpl(
   1936           HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
   1937               operand_index, operand))) {
   1938     return AppendImpl(
   1939         HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
   1940             operand_index, operand));
   1941   }
   1942 
   1943   template <typename OperandType1, typename OperandImpl1, typename OperandType2,
   1944             typename OperandImpl2>
   1945   constexpr auto WithBinaryOperandsAnyOrder(
   1946       const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
   1947       const HloInstructionPattern<OperandType2, OperandImpl2>& op2) const
   1948       -> decltype(this->AppendImpl(
   1949           HloInstructionPatternBinaryOperandsAnyOrderImpl<
   1950               OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1,
   1951                                                                       op2))) {
   1952     return AppendImpl(
   1953         HloInstructionPatternBinaryOperandsAnyOrderImpl<
   1954             OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2));
   1955   }
   1956 
   1957   // Modifies the pattern to match only if the instruction is a fusion node with
   1958   // the given kind.
   1959   constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const
   1960       -> decltype(this->AppendImpl(HloInstructionPatternFusionKindImpl(kind))) {
   1961     return AppendImpl(HloInstructionPatternFusionKindImpl(kind));
   1962   }
   1963 
   1964   // Modifies the pattern to match only if the instruction is a
   1965   // get-tuple-element with the given tuple index.
   1966   constexpr auto WithTupleIndex(int64 tuple_index) const -> decltype(
   1967       this->AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index))) {
   1968     return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index));
   1969   }
   1970 
   1971   // Modifies the pattern to match only if the instruction is a parameter
   1972   // with the given parameter number.
   1973   constexpr auto WithParameterNum(int64 parameter_num) const -> decltype(
   1974       this->AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num))) {
   1975     return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num));
   1976   }
   1977 
   1978   // Modifies the pattern to match if the instruction is used exactly once.
   1979   // Does not match if the instruction is used twice by the same user (e.g.
   1980   // multiply(x,x)).
   1981   constexpr auto WithOneUse() const
   1982       -> decltype(this->AppendImpl(HloInstructionPatternOneUseImpl())) {
   1983     return AppendImpl(HloInstructionPatternOneUseImpl());
   1984   }
   1985 
   1986   // Modifies the pattern to match if the instruction is used by exactly one
   1987   // other instruction.  Will match if the instruction is used twice, so long as
   1988   // it's by the same user (e.g.  multiply(x,x)).
   1989   constexpr auto WithOneUser() const
   1990       -> decltype(this->AppendImpl(HloInstructionPatternOneUserImpl())) {
   1991     return AppendImpl(HloInstructionPatternOneUserImpl());
   1992   }
   1993 
   1994   // Modifies the pattern to match only if the instruction has the given
   1995   // comparison direction.
   1996   auto WithComparisonDirection(ComparisonDirection direction) const
   1997       -> decltype(this->AppendImpl(
   1998           HloInstructionPatternComparisonDirectionImpl(direction))) {
   1999     return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction));
   2000   }
   2001 
   2002   void DescribeTo(std::ostream* os, int64 indent = 0) const {
   2003     impl_.DescribeTo(os, indent);
   2004   }
   2005 
   2006  private:
   2007   Impl impl_;
   2008   HloInstructionType** matched_inst_;
   2009 };
   2010 
   2011 }  // namespace detail
   2012 
   2013 // Creates an instruction pattern that will capture the matched instruction in
   2014 // the argument.
   2015 inline constexpr detail::HloInstructionPattern<
   2016     const ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl>
   2017 Op(const ::xla::HloInstruction** matched_inst = nullptr) {
   2018   return detail::HloInstructionPattern<const ::xla::HloInstruction,
   2019                                        detail::HloInstructionPatternBaseImpl>(
   2020       detail::HloInstructionPatternBaseImpl(), matched_inst);
   2021 }
   2022 
   2023 // Creates an instruction pattern that will capture the matched instruction in
   2024 // the argument.
   2025 inline constexpr detail::HloInstructionPattern<
   2026     ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl>
   2027 Op(::xla::HloInstruction** matched_inst) {
   2028   return detail::HloInstructionPattern<::xla::HloInstruction,
   2029                                        detail::HloInstructionPatternBaseImpl>(
   2030       detail::HloInstructionPatternBaseImpl(), matched_inst);
   2031 }
   2032 
   2033 // Helpers for nullary instructions.
   2034 #define XLA_NULLOP_PATTERN(NAME)                                      \
   2035   inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
   2036     return Op().WithOpcode(HloOpcode::k##NAME);                       \
   2037   }                                                                   \
   2038                                                                       \
   2039   template <typename HloInstructionType>                              \
   2040   inline auto NAME(HloInstructionType** matched_inst)                 \
   2041       ->decltype(Op(matched_inst).WithOpcode(HloOpcode::k##NAME)) {   \
   2042     return Op(matched_inst).WithOpcode(HloOpcode::k##NAME);           \
   2043   }
   2044 XLA_NULLOP_PATTERN(Constant)
   2045 XLA_NULLOP_PATTERN(Parameter)
   2046 XLA_NULLOP_PATTERN(Iota)
   2047 XLA_NULLOP_PATTERN(Rng)
   2048 #undef XLA_NULLOP_PATTERN
   2049 
   2050 // Helpers for unary instructions.
   2051 #define XLA_UNOP_PATTERN(NAME)                                        \
   2052   inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
   2053     return Op().WithOpcode(HloOpcode::k##NAME);                       \
   2054   }                                                                   \
   2055                                                                       \
   2056   template <typename Arg>                                             \
   2057   inline auto NAME(Arg&& arg)->decltype(                              \
   2058       Op().WithOpcode(HloOpcode::k##NAME)                             \
   2059           .WithOperand(0, std::forward<Arg>(arg))) {                  \
   2060     return Op()                                                       \
   2061         .WithOpcode(HloOpcode::k##NAME)                               \
   2062         .WithOperand(0, std::forward<Arg>(arg));                      \
   2063   }                                                                   \
   2064                                                                       \
   2065   template <typename HloInstructionType, typename Arg>                \
   2066   inline auto NAME(HloInstructionType** matched_inst, Arg&& arg)      \
   2067       ->decltype(Op(matched_inst)                                     \
   2068                      .WithOpcode(HloOpcode::k##NAME)                  \
   2069                      .WithOperand(0, std::forward<Arg>(arg))) {       \
   2070     return Op(matched_inst)                                           \
   2071         .WithOpcode(HloOpcode::k##NAME)                               \
   2072         .WithOperand(0, std::forward<Arg>(arg));                      \
   2073   }
   2074 XLA_UNOP_PATTERN(Abs)
   2075 XLA_UNOP_PATTERN(RoundNearestAfz)
   2076 XLA_UNOP_PATTERN(Bitcast)
   2077 XLA_UNOP_PATTERN(Broadcast)
   2078 XLA_UNOP_PATTERN(Ceil)
   2079 XLA_UNOP_PATTERN(Convert)
   2080 XLA_UNOP_PATTERN(Copy)
   2081 XLA_UNOP_PATTERN(Cos)
   2082 XLA_UNOP_PATTERN(AllReduce)
   2083 XLA_UNOP_PATTERN(Exp)
   2084 XLA_UNOP_PATTERN(Fft)
   2085 XLA_UNOP_PATTERN(Floor)
   2086 XLA_UNOP_PATTERN(GetTupleElement)
   2087 XLA_UNOP_PATTERN(Imag)
   2088 XLA_UNOP_PATTERN(Infeed)
   2089 XLA_UNOP_PATTERN(IsFinite)
   2090 XLA_UNOP_PATTERN(Log)
   2091 XLA_UNOP_PATTERN(Not)
   2092 XLA_UNOP_PATTERN(Negate)
   2093 XLA_UNOP_PATTERN(Real)
   2094 XLA_UNOP_PATTERN(Recv)
   2095 XLA_UNOP_PATTERN(RecvDone)
   2096 XLA_UNOP_PATTERN(ReducePrecision)
   2097 XLA_UNOP_PATTERN(Reshape)
   2098 XLA_UNOP_PATTERN(Reverse)
   2099 XLA_UNOP_PATTERN(Rsqrt)
   2100 XLA_UNOP_PATTERN(SendDone)
   2101 XLA_UNOP_PATTERN(Sign)
   2102 XLA_UNOP_PATTERN(Sin)
   2103 XLA_UNOP_PATTERN(Slice)
   2104 XLA_UNOP_PATTERN(Sqrt)
   2105 XLA_UNOP_PATTERN(Tanh)
   2106 XLA_UNOP_PATTERN(Transpose)
   2107 #undef XLA_UNOP_PATTERN
   2108 
   2109 // Helpers for binary instructions.
   2110 #define XLA_BINOP_PATTERN(NAME)                                             \
   2111   inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) {       \
   2112     return Op().WithOpcode(HloOpcode::k##NAME);                             \
   2113   }                                                                         \
   2114                                                                             \
   2115   template <typename Lhs, typename Rhs>                                     \
   2116   inline auto NAME(Lhs&& lhs, Rhs&& rhs)                                    \
   2117       ->decltype(Op().WithOpcode(HloOpcode::k##NAME)                        \
   2118                      .WithOperand(0, std::forward<Lhs>(lhs))                \
   2119                      .WithOperand(1, std::forward<Rhs>(rhs))) {             \
   2120     return Op()                                                             \
   2121         .WithOpcode(HloOpcode::k##NAME)                                     \
   2122         .WithOperand(0, std::forward<Lhs>(lhs))                             \
   2123         .WithOperand(1, std::forward<Rhs>(rhs));                            \
   2124   }                                                                         \
   2125                                                                             \
   2126   template <typename HloInstructionType, typename Lhs, typename Rhs>        \
   2127   inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \
   2128       ->decltype(Op(matched_inst)                                           \
   2129                      .WithOpcode(HloOpcode::k##NAME)                        \
   2130                      .WithOperand(0, std::forward<Lhs>(lhs))                \
   2131                      .WithOperand(1, std::forward<Rhs>(rhs))) {             \
   2132     return Op(matched_inst)                                                 \
   2133         .WithOpcode(HloOpcode::k##NAME)                                     \
   2134         .WithOperand(0, std::forward<Lhs>(lhs))                             \
   2135         .WithOperand(1, std::forward<Rhs>(rhs));                            \
   2136   }
   2137 
   2138 #define XLA_COMMUTATIVE_BINOP_PATTERN(NAME)                                 \
   2139   XLA_BINOP_PATTERN(NAME)                                                   \
   2140                                                                             \
   2141   template <typename HloInstructionType, typename Lhs, typename Rhs>        \
   2142   inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs,  \
   2143                              Rhs&& rhs)                                     \
   2144       ->decltype(Op(matched_inst)                                           \
   2145                      .WithOpcode(HloOpcode::k##NAME)                        \
   2146                      .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),    \
   2147                                                  std::forward<Rhs>(rhs))) { \
   2148     return Op(matched_inst)                                                 \
   2149         .WithOpcode(HloOpcode::k##NAME)                                     \
   2150         .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),                 \
   2151                                     std::forward<Rhs>(rhs));                \
   2152   }                                                                         \
   2153   template <typename Lhs, typename Rhs>                                     \
   2154   inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs)                          \
   2155       ->decltype(NAME##AnyOrder<const HloInstruction>(                      \
   2156           nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs))) {       \
   2157     return NAME##AnyOrder<const HloInstruction>(                            \
   2158         nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs));           \
   2159   }
   2160 XLA_COMMUTATIVE_BINOP_PATTERN(Add)
   2161 XLA_BINOP_PATTERN(Atan2)
   2162 XLA_BINOP_PATTERN(Divide)
   2163 XLA_BINOP_PATTERN(Complex)
   2164 XLA_BINOP_PATTERN(Compare)
   2165 XLA_BINOP_PATTERN(Convolution)
   2166 XLA_BINOP_PATTERN(Dot)
   2167 XLA_BINOP_PATTERN(Gather)
   2168 XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
   2169 XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
   2170 XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
   2171 XLA_BINOP_PATTERN(Outfeed)
   2172 XLA_BINOP_PATTERN(Pad)
   2173 XLA_BINOP_PATTERN(Power)
   2174 XLA_BINOP_PATTERN(ReduceWindow)
   2175 XLA_BINOP_PATTERN(Remainder)
   2176 XLA_BINOP_PATTERN(Send)
   2177 XLA_BINOP_PATTERN(Subtract)
   2178 XLA_COMMUTATIVE_BINOP_PATTERN(And)
   2179 XLA_COMMUTATIVE_BINOP_PATTERN(Or)
   2180 XLA_BINOP_PATTERN(ShiftLeft)
   2181 XLA_BINOP_PATTERN(ShiftRightArithmetic)
   2182 XLA_BINOP_PATTERN(ShiftRightLogical)
   2183 #undef XLA_COMMUTATIVE_BINOP_PATTERN
   2184 #undef XLA_BINOP_PATTERN
   2185 
   2186 // Helpers for ternary instructions.
   2187 #define XLA_TERNOP_PATTERN(NAME)                                       \
   2188   inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) {  \
   2189     return Op().WithOpcode(HloOpcode::k##NAME);                        \
   2190   }                                                                    \
   2191                                                                        \
   2192   template <typename Arg0, typename Arg1, typename Arg2>               \
   2193   inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2)              \
   2194       ->decltype(Op().WithOpcode(HloOpcode::k##NAME)                   \
   2195                      .WithOperand(0, std::forward<Arg0>(arg0))         \
   2196                      .WithOperand(1, std::forward<Arg1>(arg1))         \
   2197                      .WithOperand(2, std::forward<Arg2>(arg2))) {      \
   2198     return Op()                                                        \
   2199         .WithOpcode(HloOpcode::k##NAME)                                \
   2200         .WithOperand(0, std::forward<Arg0>(arg0))                      \
   2201         .WithOperand(1, std::forward<Arg1>(arg1))                      \
   2202         .WithOperand(2, std::forward<Arg2>(arg2));                     \
   2203   }                                                                    \
   2204                                                                        \
   2205   template <typename HloInstructionType, typename Arg0, typename Arg1, \
   2206             typename Arg2>                                             \
   2207   inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0,     \
   2208                    Arg1&& arg1, Arg2&& arg2)                           \
   2209       ->decltype(Op(matched_inst)                                      \
   2210                      .WithOpcode(HloOpcode::k##NAME)                   \
   2211                      .WithOperand(0, std::forward<Arg0>(arg0))         \
   2212                      .WithOperand(1, std::forward<Arg1>(arg1))         \
   2213                      .WithOperand(2, std::forward<Arg2>(arg2))) {      \
   2214     return Op(matched_inst)                                            \
   2215         .WithOpcode(HloOpcode::k##NAME)                                \
   2216         .WithOperand(0, std::forward<Arg0>(arg0))                      \
   2217         .WithOperand(1, std::forward<Arg1>(arg1))                      \
   2218         .WithOperand(2, std::forward<Arg2>(arg2));                     \
   2219   }
   2220 XLA_TERNOP_PATTERN(Clamp);
   2221 XLA_TERNOP_PATTERN(Scatter);
   2222 XLA_TERNOP_PATTERN(Select);
   2223 #undef XLA_TERNOP_PATTERN
   2224 
   2225 namespace detail {
   2226 template <typename Matcher, typename FirstArg>
   2227 inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg)
   2228     -> decltype(m.WithOperand(operand_num, std::forward<FirstArg>(first_arg))) {
   2229   return m.WithOperand(operand_num, std::forward<FirstArg>(first_arg));
   2230 }
   2231 
   2232 template <typename Matcher, typename FirstArg, typename... Args>
   2233 inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg,
   2234                          Args&&... args)
   2235     -> decltype(WithOperands(m.WithOperand(operand_num,
   2236                                            std::forward<FirstArg>(first_arg)),
   2237                              operand_num + 1, std::forward<Args>(args)...)) {
   2238   return WithOperands(
   2239       m.WithOperand(operand_num, std::forward<FirstArg>(first_arg)),
   2240       operand_num + 1, std::forward<Args>(args)...);
   2241 }
   2242 }  // namespace detail
   2243 
   2244 #define XLA_VARIADIC_OP_PATTERN(NAME)                                         \
   2245   inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) {         \
   2246     return Op().WithOpcode(HloOpcode::k##NAME);                               \
   2247   }                                                                           \
   2248                                                                               \
   2249   template <typename... Args>                                                 \
   2250   inline auto NAME(Args&&... args)                                            \
   2251       ->decltype(detail::WithOperands(Op().WithOpcode(HloOpcode::k##NAME)     \
   2252                                           .WithNumOperands(sizeof...(Args)),  \
   2253                                       0, std::forward<Args>(args)...)) {      \
   2254     return detail::WithOperands(                                              \
   2255         Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \
   2256         /*operand_num=*/0, std::forward<Args>(args)...);                      \
   2257   }                                                                           \
   2258                                                                               \
   2259   template <typename HloInstructionType, typename... Args>                    \
   2260   inline auto NAME(HloInstructionType** matched_inst, Args&&... args)         \
   2261       ->decltype(detail::WithOperands(Op(matched_inst)                        \
   2262                                           .WithOpcode(HloOpcode::k##NAME)     \
   2263                                           .WithNumOperands(sizeof...(Args)),  \
   2264                                       0, std::forward<Args>(args)...)) {      \
   2265     return detail::WithOperands(Op(matched_inst)                              \
   2266                                     .WithOpcode(HloOpcode::k##NAME)           \
   2267                                     .WithNumOperands(sizeof...(Args)),        \
   2268                                 /*operand_num=*/0,                            \
   2269                                 std::forward<Args>(args)...);                 \
   2270   }
   2271 
   2272 // We could implement all ops as "variadic" ops, but it would make the
   2273 // already-bad compile errors even worse.
   2274 XLA_VARIADIC_OP_PATTERN(AfterAll);
   2275 XLA_VARIADIC_OP_PATTERN(Concatenate);
   2276 XLA_VARIADIC_OP_PATTERN(CustomCall);
   2277 XLA_VARIADIC_OP_PATTERN(DynamicSlice)
   2278 XLA_VARIADIC_OP_PATTERN(Map)
   2279 XLA_VARIADIC_OP_PATTERN(Reduce);
   2280 XLA_VARIADIC_OP_PATTERN(Sort);
   2281 XLA_VARIADIC_OP_PATTERN(Tuple);
   2282 
   2283 // Helpers for comparison instructions.
   2284 #define XLA_COMPARE_PATTERN(NAME)                                              \
   2285   inline auto NAME()->decltype(                                                \
   2286       Op().WithOpcode(HloOpcode::kCompare)                                     \
   2287           .WithComparisonDirection(ComparisonDirection::k##NAME)) {            \
   2288     return Op()                                                                \
   2289         .WithOpcode(HloOpcode::kCompare)                                       \
   2290         .WithComparisonDirection(ComparisonDirection::k##NAME);                \
   2291   }                                                                            \
   2292                                                                                \
   2293   template <typename Lhs, typename Rhs>                                        \
   2294   inline auto NAME(Lhs&& lhs, Rhs&& rhs)                                       \
   2295       ->decltype(Op().WithOpcode(HloOpcode::kCompare)                          \
   2296                      .WithOperand(0, std::forward<Lhs>(lhs))                   \
   2297                      .WithOperand(1, std::forward<Rhs>(rhs))                   \
   2298                      .WithComparisonDirection(ComparisonDirection::k##NAME)) { \
   2299     return Op()                                                                \
   2300         .WithOpcode(HloOpcode::kCompare)                                       \
   2301         .WithOperand(0, std::forward<Lhs>(lhs))                                \
   2302         .WithOperand(1, std::forward<Rhs>(rhs))                                \
   2303         .WithComparisonDirection(ComparisonDirection::k##NAME);                \
   2304   }                                                                            \
   2305                                                                                \
   2306   template <typename HloInstructionType, typename Lhs, typename Rhs>           \
   2307   inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs)    \
   2308       ->decltype(Op(matched_inst)                                              \
   2309                      .WithOpcode(HloOpcode::kCompare)                          \
   2310                      .WithOperand(0, std::forward<Lhs>(lhs))                   \
   2311                      .WithOperand(1, std::forward<Rhs>(rhs))                   \
   2312                      .WithComparisonDirection(ComparisonDirection::k##NAME)) { \
   2313     return Op(matched_inst)                                                    \
   2314         .WithOpcode(HloOpcode::kCompare)                                       \
   2315         .WithOperand(0, std::forward<Lhs>(lhs))                                \
   2316         .WithOperand(1, std::forward<Rhs>(rhs))                                \
   2317         .WithComparisonDirection(ComparisonDirection::k##NAME);                \
   2318   }
   2319 
   2320 #define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME)                               \
   2321   XLA_COMPARE_PATTERN(NAME)                                                 \
   2322                                                                             \
   2323   template <typename HloInstructionType, typename Lhs, typename Rhs>        \
   2324   inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs,  \
   2325                              Rhs&& rhs)                                     \
   2326       ->decltype(Op(matched_inst)                                           \
   2327                      .WithOpcode(HloOpcode::kCompare)                       \
   2328                      .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),    \
   2329                                                  std::forward<Rhs>(rhs))) { \
   2330     return Op(matched_inst)                                                 \
   2331         .WithOpcode(HloOpcode::kCompare)                                    \
   2332         .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),                 \
   2333                                     std::forward<Rhs>(rhs));                \
   2334   }                                                                         \
   2335   template <typename Lhs, typename Rhs>                                     \
   2336   inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs)                          \
   2337       ->decltype(NAME##AnyOrder<const HloInstruction>(                      \
   2338           nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs))) {       \
   2339     return NAME##AnyOrder<const HloInstruction>(                            \
   2340         nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs));           \
   2341   }
   2342 
   2343 XLA_COMMUTATIVE_COMPARE_PATTERN(Eq);
   2344 XLA_COMMUTATIVE_COMPARE_PATTERN(Ne);
   2345 XLA_COMPARE_PATTERN(Ge);
   2346 XLA_COMPARE_PATTERN(Gt);
   2347 XLA_COMPARE_PATTERN(Le);
   2348 XLA_COMPARE_PATTERN(Lt);
   2349 
   2350 // Helpers for matching non-constant instructions.
   2351 inline auto NonConstant() -> decltype(Op().IsNonConstant()) {
   2352   return Op().IsNonConstant();
   2353 }
   2354 
   2355 template <typename HloInstructionType>
   2356 inline auto NonConstant(HloInstructionType** matched_inst)
   2357     -> decltype(Op(matched_inst).IsNonConstant()) {
   2358   return Op(matched_inst).IsNonConstant();
   2359 }
   2360 
   2361 // Add overloads for GetTupleElement which take a int64 specifying which tuple
   2362 // element is selected.
   2363 template <typename Arg>
   2364 inline auto GetTupleElement(Arg&& arg, int64 tuple_index)
   2365     -> decltype(Op().WithOpcode(HloOpcode::kGetTupleElement)
   2366                     .WithOperand(0, std::forward<Arg>(arg))
   2367                     .WithTupleIndex(tuple_index)) {
   2368   return Op()
   2369       .WithOpcode(HloOpcode::kGetTupleElement)
   2370       .WithOperand(0, std::forward<Arg>(arg))
   2371       .WithTupleIndex(tuple_index);
   2372 }
   2373 
   2374 template <typename HloInstructionType, typename Arg>
   2375 inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
   2376                             int64 tuple_index)
   2377     -> decltype(Op(matched_inst)
   2378                     .WithOpcode(HloOpcode::kGetTupleElement)
   2379                     .WithOperand(0, std::forward<Arg>(arg))
   2380                     .WithTupleIndex(tuple_index)) {
   2381   return Op(matched_inst)
   2382       .WithOpcode(HloOpcode::kGetTupleElement)
   2383       .WithOperand(0, std::forward<Arg>(arg))
   2384       .WithTupleIndex(tuple_index);
   2385 }
   2386 
   2387 // Add overloads for Parameter which take an int64 specifying the parameter
   2388 // number.
   2389 inline auto Parameter(int64 parameter_num) -> decltype(
   2390     Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num)) {
   2391   return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num);
   2392 }
   2393 template <typename HloInstructionType>
   2394 inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num)
   2395     -> decltype(Op(matched_inst)
   2396                     .WithOpcode(HloOpcode::kParameter)
   2397                     .WithParameterNum(parameter_num)) {
   2398   return Op(matched_inst)
   2399       .WithOpcode(HloOpcode::kParameter)
   2400       .WithParameterNum(parameter_num);
   2401 }
   2402 
   2403 inline auto ConstantScalar() -> decltype(Op().IsConstantScalar()) {
   2404   return Op().IsConstantScalar();
   2405 }
   2406 
   2407 template <typename HloInstructionType>
   2408 inline auto ConstantScalar(HloInstructionType** matched_inst)
   2409     -> decltype(Op(matched_inst).IsConstantScalar()) {
   2410   return Op(matched_inst).IsConstantScalar();
   2411 }
   2412 
   2413 template <typename ScalarTy>
   2414 inline auto ConstantScalar(ScalarTy val)
   2415     -> decltype(Op().IsConstantScalar(val)) {
   2416   return Op().IsConstantScalar(val);
   2417 }
   2418 
   2419 template <typename HloInstructionType, typename ScalarTy>
   2420 inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val)
   2421     -> decltype(Op(matched_inst).IsConstantScalar(val)) {
   2422   return Op(matched_inst).IsConstantScalar(val);
   2423 }
   2424 
   2425 inline auto ConstantEffectiveScalar() -> decltype(Op().IsConstantScalar()) {
   2426   return Op().IsConstantEffectiveScalar();
   2427 }
   2428 
   2429 template <typename HloInstructionType>
   2430 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst)
   2431     -> decltype(Op(matched_inst).IsConstantScalar()) {
   2432   return Op(matched_inst).IsConstantEffectiveScalar();
   2433 }
   2434 
   2435 template <typename ScalarTy>
   2436 inline auto ConstantEffectiveScalar(ScalarTy val)
   2437     -> decltype(Op().IsConstantEffectiveScalar(val)) {
   2438   return Op().IsConstantEffectiveScalar(val);
   2439 }
   2440 
   2441 template <typename HloInstructionType, typename ScalarTy>
   2442 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst,
   2443                                     ScalarTy val)
   2444     -> decltype(Op(matched_inst).IsConstantEffectiveScalar(val)) {
   2445   return Op(matched_inst).IsConstantEffectiveScalar(val);
   2446 }
   2447 
   2448 }  // namespace match
   2449 
   2450 }  // namespace xla
   2451 
   2452 #undef EXPLAIN
   2453 #pragma pop_macro("EXPLAIN")
   2454 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
   2455