1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ 18 19 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 20 #include "tensorflow/compiler/xla/test.h" 21 22 namespace xla { 23 namespace testing { 24 25 class HloMatcher : public ::testing::MatcherInterface<const HloInstruction*> { 26 public: 27 HloMatcher(HloOpcode opcode, 28 std::vector<::testing::Matcher<const HloInstruction*>> operands) 29 : opcode_(opcode), operands_(operands) {} 30 31 bool MatchAndExplain(const HloInstruction* instruction, 32 ::testing::MatchResultListener* listener) const override; 33 34 void DescribeTo(::std::ostream* os) const override; 35 36 private: 37 HloOpcode opcode_; 38 std::vector<::testing::Matcher<const HloInstruction*>> operands_; 39 }; 40 41 // Custom matcher for parameters, which accepts a parameter number. 42 class HloParameterMatcher : public HloMatcher { 43 public: 44 explicit HloParameterMatcher(int64 parameter_number) 45 : HloMatcher(HloOpcode::kParameter, /*operands=*/{}), 46 parameter_number_(parameter_number) {} 47 48 bool MatchAndExplain(const HloInstruction* instruction, 49 ::testing::MatchResultListener* listener) const override; 50 51 private: 52 int64 parameter_number_; 53 }; 54 55 // Custom matcher for get-tuple-element instructions, which accepts a tuple 56 // index to match. 57 class HloGetTupleElementMatcher : public HloMatcher { 58 public: 59 HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction*> operand, 60 int64 tuple_index) 61 : HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}), 62 tuple_index_(tuple_index) {} 63 64 bool MatchAndExplain(const HloInstruction* instruction, 65 ::testing::MatchResultListener* listener) const override; 66 67 private: 68 int64 tuple_index_; 69 }; 70 71 // Custom matcher for custom-call instructions, which accepts a matcher for its 72 // call target. 73 class HloCustomCallMatcher : public HloMatcher { 74 public: 75 HloCustomCallMatcher( 76 ::testing::Matcher<string> call_target_matcher, 77 std::vector<::testing::Matcher<const HloInstruction*>> operands) 78 : HloMatcher(HloOpcode::kCustomCall, operands), 79 call_target_matcher_(call_target_matcher) {} 80 81 bool MatchAndExplain(const HloInstruction* instruction, 82 ::testing::MatchResultListener* listener) const override; 83 void DescribeTo(std::ostream* os) const override; 84 85 private: 86 ::testing::Matcher<string> call_target_matcher_; 87 }; 88 89 // HloInstruction* matchers for opcode and operands. Example: 90 // namespace op = xla::opcode_matchers; 91 // EXPECT_THAT(instruction, 92 // op::Add(op::Reshape(), op::Add(op::Reshape(), _))); 93 namespace opcode_matchers { 94 #define HLO_MATCHER(opcode) \ 95 template <typename... M> \ 96 ::testing::Matcher<const ::xla::HloInstruction*> opcode(M... operands) { \ 97 return ::testing::MakeMatcher(new ::xla::testing::HloMatcher( \ 98 ::xla::HloOpcode::k##opcode, {operands...})); \ 99 } 100 HLO_MATCHER(Abs); 101 HLO_MATCHER(Add); 102 HLO_MATCHER(Bitcast); 103 HLO_MATCHER(Broadcast); 104 HLO_MATCHER(BatchNormGrad); 105 HLO_MATCHER(Call); 106 HLO_MATCHER(Ceil); 107 HLO_MATCHER(Clamp); 108 HLO_MATCHER(Concatenate); 109 HLO_MATCHER(Conditional); 110 HLO_MATCHER(Constant); 111 HLO_MATCHER(Convert); 112 HLO_MATCHER(Convolution); 113 HLO_MATCHER(Copy); 114 HLO_MATCHER(CrossReplicaSum); 115 HLO_MATCHER(Divide); 116 HLO_MATCHER(Dot); 117 HLO_MATCHER(DynamicSlice); 118 HLO_MATCHER(DynamicUpdateSlice); 119 HLO_MATCHER(Eq); 120 HLO_MATCHER(Exp); 121 HLO_MATCHER(Floor); 122 HLO_MATCHER(Fusion); 123 HLO_MATCHER(Ge); 124 HLO_MATCHER(Gt); 125 HLO_MATCHER(Infeed); 126 HLO_MATCHER(IsFinite); 127 HLO_MATCHER(Le); 128 HLO_MATCHER(Log); 129 HLO_MATCHER(And); 130 HLO_MATCHER(Not); 131 HLO_MATCHER(Or); 132 HLO_MATCHER(Lt); 133 HLO_MATCHER(Map); 134 HLO_MATCHER(Maximum); 135 HLO_MATCHER(Minimum); 136 HLO_MATCHER(Multiply); 137 HLO_MATCHER(Ne); 138 HLO_MATCHER(Negate); 139 HLO_MATCHER(Outfeed); 140 HLO_MATCHER(Pad); 141 HLO_MATCHER(Power); 142 HLO_MATCHER(Recv); 143 HLO_MATCHER(RecvDone); 144 HLO_MATCHER(Reduce); 145 HLO_MATCHER(ReducePrecision); 146 HLO_MATCHER(ReduceWindow); 147 HLO_MATCHER(Remainder); 148 HLO_MATCHER(Reshape); 149 HLO_MATCHER(Reverse); 150 HLO_MATCHER(Rng); 151 HLO_MATCHER(Select); 152 HLO_MATCHER(SelectAndScatter); 153 HLO_MATCHER(Send); 154 HLO_MATCHER(SendDone); 155 HLO_MATCHER(ShiftLeft); 156 HLO_MATCHER(ShiftRightLogical); 157 HLO_MATCHER(ShiftRightArithmetic); 158 HLO_MATCHER(Sign); 159 HLO_MATCHER(Slice); 160 HLO_MATCHER(Sort); 161 HLO_MATCHER(Subtract); 162 HLO_MATCHER(Tanh); 163 HLO_MATCHER(Trace); 164 HLO_MATCHER(Transpose); 165 HLO_MATCHER(Tuple); 166 HLO_MATCHER(While); 167 168 // The special cases below let you check additional information about the 169 // HloInstruction, beyond just its opcode and operands. In all cases you can 170 // still use the generic matcher which doesn't check this info. 171 // 172 // Feel free to add additional custom matchers below. 173 174 // - Parameter(N) matches parameter number N. 175 // - Parameter() matches any parameter. 176 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter( 177 int64 parameter_number) { 178 return ::testing::MakeMatcher( 179 new ::xla::testing::HloParameterMatcher(parameter_number)); 180 } 181 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter() { 182 return ::testing::MakeMatcher( 183 new ::xla::testing::HloMatcher(HloOpcode::kParameter, {})); 184 } 185 186 // GetTupleElement(operand, N) matches a GTE instruction which gets the N'th 187 // tuple element of operand, while GetTupleElement(operand) matches any GTE 188 // operation on operand, and GetTupleElement() matches any GTE operation at all. 189 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement( 190 ::testing::Matcher<const HloInstruction*> operand, int64 tuple_index) { 191 return ::testing::MakeMatcher( 192 new ::xla::testing::HloGetTupleElementMatcher(operand, tuple_index)); 193 } 194 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement( 195 ::testing::Matcher<const HloInstruction*> operand) { 196 return ::testing::MakeMatcher( 197 new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {operand})); 198 } 199 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement() { 200 return ::testing::MakeMatcher( 201 new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {})); 202 } 203 204 // - CustomCall(T, operand1, ..., operandN) matches a CustomCall with call 205 // target T and the given operands. 206 // 207 // - CustomCall(operand1, ..., operandN) matches any CustomCall HLO with the 208 // given operands. 209 // 210 // - CustomCall() matches any CustomCall HLO at all. 211 template <typename... M> 212 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall( 213 ::testing::Matcher<string> call_target_matcher, M... operands) { 214 return ::testing::MakeMatcher(new ::xla::testing::HloCustomCallMatcher( 215 call_target_matcher, {operands...})); 216 } 217 // This overload of CustomCall(A, B, C, ...) exists iff A is not convertible to 218 // ::testing::Matcher<string>. In that case, we want to prefer the overload 219 // above. 220 template <typename FirstM, typename... M, 221 typename Dummy = typename std::enable_if< 222 !std::is_convertible<FirstM, ::testing::Matcher<string>>::value, 223 void>::type*> 224 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall( 225 FirstM operands_first, M... operands_rest) { 226 return ::testing::MakeMatcher(new ::xla::testing::HloMatcher( 227 HloOpcode::kCustomCall, {operands_first, operands_rest...})); 228 } 229 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall() { 230 return ::testing::MakeMatcher( 231 new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {})); 232 } 233 234 #undef HLO_MATCHER 235 } // namespace opcode_matchers 236 237 // Helper to convert smart to raw pointers for matching. 238 template <typename Container> 239 std::vector<const HloInstruction*> Pointers(const Container& container) { 240 std::vector<const HloInstruction*> result; 241 result.reserve(container.size()); 242 for (const auto& entry : container) result.push_back(entry.get()); 243 return result; 244 } 245 246 } // namespace testing 247 248 // Tell GMock to print HloInstruction* by value, so error messages are nice. 249 // Has to be in the same namespace as 'HloInstruction'. 250 void PrintTo(const HloInstruction* inst, ::std::ostream* os); 251 void PrintTo(HloInstruction* inst, ::std::ostream* os); 252 253 } // namespace xla 254 255 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ 256