Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
     18 
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
     21 
     22 namespace xla {
     23 namespace testing {
     24 
     25 class HloMatcher : public ::testing::MatcherInterface<const HloInstruction*> {
     26  public:
     27   HloMatcher(HloOpcode opcode,
     28              std::vector<::testing::Matcher<const HloInstruction*>> operands)
     29       : opcode_(opcode), operands_(operands) {}
     30 
     31   bool MatchAndExplain(const HloInstruction* instruction,
     32                        ::testing::MatchResultListener* listener) const override;
     33 
     34   void DescribeTo(::std::ostream* os) const override;
     35 
     36  private:
     37   HloOpcode opcode_;
     38   std::vector<::testing::Matcher<const HloInstruction*>> operands_;
     39 };
     40 
     41 // Custom matcher for parameters, which accepts a parameter number.
     42 class HloParameterMatcher : public HloMatcher {
     43  public:
     44   explicit HloParameterMatcher(int64 parameter_number)
     45       : HloMatcher(HloOpcode::kParameter, /*operands=*/{}),
     46         parameter_number_(parameter_number) {}
     47 
     48   bool MatchAndExplain(const HloInstruction* instruction,
     49                        ::testing::MatchResultListener* listener) const override;
     50 
     51  private:
     52   int64 parameter_number_;
     53 };
     54 
     55 // Custom matcher for get-tuple-element instructions, which accepts a tuple
     56 // index to match.
     57 class HloGetTupleElementMatcher : public HloMatcher {
     58  public:
     59   HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction*> operand,
     60                             int64 tuple_index)
     61       : HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}),
     62         tuple_index_(tuple_index) {}
     63 
     64   bool MatchAndExplain(const HloInstruction* instruction,
     65                        ::testing::MatchResultListener* listener) const override;
     66 
     67  private:
     68   int64 tuple_index_;
     69 };
     70 
     71 // Custom matcher for custom-call instructions, which accepts a matcher for its
     72 // call target.
     73 class HloCustomCallMatcher : public HloMatcher {
     74  public:
     75   HloCustomCallMatcher(
     76       ::testing::Matcher<string> call_target_matcher,
     77       std::vector<::testing::Matcher<const HloInstruction*>> operands)
     78       : HloMatcher(HloOpcode::kCustomCall, operands),
     79         call_target_matcher_(call_target_matcher) {}
     80 
     81   bool MatchAndExplain(const HloInstruction* instruction,
     82                        ::testing::MatchResultListener* listener) const override;
     83   void DescribeTo(std::ostream* os) const override;
     84 
     85  private:
     86   ::testing::Matcher<string> call_target_matcher_;
     87 };
     88 
     89 // HloInstruction* matchers for opcode and operands. Example:
     90 //   namespace op = xla::opcode_matchers;
     91 //   EXPECT_THAT(instruction,
     92 //               op::Add(op::Reshape(), op::Add(op::Reshape(), _)));
     93 namespace opcode_matchers {
     94 #define HLO_MATCHER(opcode)                                                \
     95   template <typename... M>                                                 \
     96   ::testing::Matcher<const ::xla::HloInstruction*> opcode(M... operands) { \
     97     return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(          \
     98         ::xla::HloOpcode::k##opcode, {operands...}));                      \
     99   }
    100 HLO_MATCHER(Abs);
    101 HLO_MATCHER(Add);
    102 HLO_MATCHER(Bitcast);
    103 HLO_MATCHER(Broadcast);
    104 HLO_MATCHER(BatchNormGrad);
    105 HLO_MATCHER(Call);
    106 HLO_MATCHER(Ceil);
    107 HLO_MATCHER(Clamp);
    108 HLO_MATCHER(Concatenate);
    109 HLO_MATCHER(Conditional);
    110 HLO_MATCHER(Constant);
    111 HLO_MATCHER(Convert);
    112 HLO_MATCHER(Convolution);
    113 HLO_MATCHER(Copy);
    114 HLO_MATCHER(CrossReplicaSum);
    115 HLO_MATCHER(Divide);
    116 HLO_MATCHER(Dot);
    117 HLO_MATCHER(DynamicSlice);
    118 HLO_MATCHER(DynamicUpdateSlice);
    119 HLO_MATCHER(Eq);
    120 HLO_MATCHER(Exp);
    121 HLO_MATCHER(Floor);
    122 HLO_MATCHER(Fusion);
    123 HLO_MATCHER(Ge);
    124 HLO_MATCHER(Gt);
    125 HLO_MATCHER(Infeed);
    126 HLO_MATCHER(IsFinite);
    127 HLO_MATCHER(Le);
    128 HLO_MATCHER(Log);
    129 HLO_MATCHER(And);
    130 HLO_MATCHER(Not);
    131 HLO_MATCHER(Or);
    132 HLO_MATCHER(Lt);
    133 HLO_MATCHER(Map);
    134 HLO_MATCHER(Maximum);
    135 HLO_MATCHER(Minimum);
    136 HLO_MATCHER(Multiply);
    137 HLO_MATCHER(Ne);
    138 HLO_MATCHER(Negate);
    139 HLO_MATCHER(Outfeed);
    140 HLO_MATCHER(Pad);
    141 HLO_MATCHER(Power);
    142 HLO_MATCHER(Recv);
    143 HLO_MATCHER(RecvDone);
    144 HLO_MATCHER(Reduce);
    145 HLO_MATCHER(ReducePrecision);
    146 HLO_MATCHER(ReduceWindow);
    147 HLO_MATCHER(Remainder);
    148 HLO_MATCHER(Reshape);
    149 HLO_MATCHER(Reverse);
    150 HLO_MATCHER(Rng);
    151 HLO_MATCHER(Select);
    152 HLO_MATCHER(SelectAndScatter);
    153 HLO_MATCHER(Send);
    154 HLO_MATCHER(SendDone);
    155 HLO_MATCHER(ShiftLeft);
    156 HLO_MATCHER(ShiftRightLogical);
    157 HLO_MATCHER(ShiftRightArithmetic);
    158 HLO_MATCHER(Sign);
    159 HLO_MATCHER(Slice);
    160 HLO_MATCHER(Sort);
    161 HLO_MATCHER(Subtract);
    162 HLO_MATCHER(Tanh);
    163 HLO_MATCHER(Trace);
    164 HLO_MATCHER(Transpose);
    165 HLO_MATCHER(Tuple);
    166 HLO_MATCHER(While);
    167 
    168 // The special cases below let you check additional information about the
    169 // HloInstruction, beyond just its opcode and operands.  In all cases you can
    170 // still use the generic matcher which doesn't check this info.
    171 //
    172 // Feel free to add additional custom matchers below.
    173 
    174 //  - Parameter(N) matches parameter number N.
    175 //  - Parameter() matches any parameter.
    176 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter(
    177     int64 parameter_number) {
    178   return ::testing::MakeMatcher(
    179       new ::xla::testing::HloParameterMatcher(parameter_number));
    180 }
    181 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter() {
    182   return ::testing::MakeMatcher(
    183       new ::xla::testing::HloMatcher(HloOpcode::kParameter, {}));
    184 }
    185 
    186 // GetTupleElement(operand, N) matches a GTE instruction which gets the N'th
    187 // tuple element of operand, while GetTupleElement(operand) matches any GTE
    188 // operation on operand, and GetTupleElement() matches any GTE operation at all.
    189 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
    190     ::testing::Matcher<const HloInstruction*> operand, int64 tuple_index) {
    191   return ::testing::MakeMatcher(
    192       new ::xla::testing::HloGetTupleElementMatcher(operand, tuple_index));
    193 }
    194 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
    195     ::testing::Matcher<const HloInstruction*> operand) {
    196   return ::testing::MakeMatcher(
    197       new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {operand}));
    198 }
    199 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement() {
    200   return ::testing::MakeMatcher(
    201       new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {}));
    202 }
    203 
    204 // - CustomCall(T, operand1, ..., operandN) matches a CustomCall with call
    205 //   target T and the given operands.
    206 //
    207 // - CustomCall(operand1, ..., operandN) matches any CustomCall HLO with the
    208 //   given operands.
    209 //
    210 // - CustomCall() matches any CustomCall HLO at all.
    211 template <typename... M>
    212 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
    213     ::testing::Matcher<string> call_target_matcher, M... operands) {
    214   return ::testing::MakeMatcher(new ::xla::testing::HloCustomCallMatcher(
    215       call_target_matcher, {operands...}));
    216 }
    217 // This overload of CustomCall(A, B, C, ...) exists iff A is not convertible to
    218 // ::testing::Matcher<string>.  In that case, we want to prefer the overload
    219 // above.
    220 template <typename FirstM, typename... M,
    221           typename Dummy = typename std::enable_if<
    222               !std::is_convertible<FirstM, ::testing::Matcher<string>>::value,
    223               void>::type*>
    224 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
    225     FirstM operands_first, M... operands_rest) {
    226   return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
    227       HloOpcode::kCustomCall, {operands_first, operands_rest...}));
    228 }
    229 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall() {
    230   return ::testing::MakeMatcher(
    231       new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {}));
    232 }
    233 
    234 #undef HLO_MATCHER
    235 }  // namespace opcode_matchers
    236 
    237 // Helper to convert smart to raw pointers for matching.
    238 template <typename Container>
    239 std::vector<const HloInstruction*> Pointers(const Container& container) {
    240   std::vector<const HloInstruction*> result;
    241   result.reserve(container.size());
    242   for (const auto& entry : container) result.push_back(entry.get());
    243   return result;
    244 }
    245 
    246 }  // namespace testing
    247 
    248 // Tell GMock to print HloInstruction* by value, so error messages are nice.
    249 // Has to be in the same namespace as 'HloInstruction'.
    250 void PrintTo(const HloInstruction* inst, ::std::ostream* os);
    251 void PrintTo(HloInstruction* inst, ::std::ostream* os);
    252 
    253 }  // namespace xla
    254 
    255 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
    256