1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 17 18 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 19 #include "tensorflow/compiler/xla/test.h" 20 21 namespace xla { 22 namespace testing { 23 24 bool HloMatcher::MatchAndExplain( 25 const HloInstruction* instruction, 26 ::testing::MatchResultListener* listener) const { 27 // These cases are self-explanatory from the printed value. 28 if (!instruction || instruction->opcode() != opcode_) { 29 return false; 30 } 31 // Special case: no operand matchers means don't verify. 32 if (operands_.empty()) { 33 return true; 34 } 35 const auto& operands = instruction->operands(); 36 if (operands.size() != operands_.size()) { 37 *listener << "has too " 38 << (operands.size() > operands_.size() ? "many" : "few") 39 << " operands (got " << operands.size() << ", want " 40 << operands_.size() << ")"; 41 return false; 42 } 43 for (int index = 0; index < operands.size(); index++) { 44 ::testing::StringMatchResultListener inner_listener; 45 if (!operands_[index].MatchAndExplain(operands[index], &inner_listener)) { 46 if (listener->IsInterested()) { 47 *listener << "\noperand " << index << ":\n\t" 48 << operands[index]->ToString() 49 << "\ndoesn't match expected:\n\t"; 50 operands_[index].DescribeTo(listener->stream()); 51 string explanation = inner_listener.str(); 52 if (!explanation.empty()) { 53 *listener << ", " << explanation; 54 } 55 } 56 return false; 57 } 58 } 59 return true; 60 } 61 62 void HloMatcher::DescribeTo(::std::ostream* os) const { 63 *os << opcode_; 64 if (!operands_.empty()) { 65 *os << "("; 66 for (int i = 0; i < operands_.size(); i++) { 67 if (i > 0) { 68 *os << ", "; 69 } 70 operands_[i].DescribeTo(os); 71 } 72 *os << ")"; 73 } 74 } 75 76 bool HloParameterMatcher::MatchAndExplain( 77 const HloInstruction* instruction, 78 ::testing::MatchResultListener* listener) const { 79 if (!HloMatcher::MatchAndExplain(instruction, listener)) { 80 return false; 81 } 82 if (instruction->parameter_number() != parameter_number_) { 83 *listener << "has wrong parameter number (got " 84 << instruction->parameter_number() << ", want " 85 << parameter_number_ << ")"; 86 return false; 87 } 88 return true; 89 } 90 91 bool HloGetTupleElementMatcher::MatchAndExplain( 92 const HloInstruction* instruction, 93 ::testing::MatchResultListener* listener) const { 94 if (!HloMatcher::MatchAndExplain(instruction, listener)) { 95 return false; 96 } 97 if (instruction->tuple_index() != tuple_index_) { 98 *listener << "has wrong tuple index (got " << instruction->tuple_index() 99 << ", want " << tuple_index_ << ")"; 100 return false; 101 } 102 return true; 103 } 104 105 void HloCustomCallMatcher::DescribeTo(std::ostream* os) const { 106 HloMatcher::DescribeTo(os); 107 *os << " with call target that "; 108 call_target_matcher_.DescribeTo(os); 109 } 110 111 bool HloCustomCallMatcher::MatchAndExplain( 112 const HloInstruction* instruction, 113 ::testing::MatchResultListener* listener) const { 114 if (!HloMatcher::MatchAndExplain(instruction, listener)) { 115 return false; 116 } 117 ::testing::StringMatchResultListener sub_listener; 118 bool result = ExplainMatchResult( 119 call_target_matcher_, instruction->custom_call_target(), &sub_listener); 120 if (sub_listener.str().empty()) { 121 sub_listener << " that "; 122 123 std::stringstream desc_stream; 124 if (result) { 125 call_target_matcher_.DescribeTo(&desc_stream); 126 } else { 127 call_target_matcher_.DescribeNegationTo(&desc_stream); 128 } 129 sub_listener << desc_stream.str(); 130 } 131 *listener << "custom-call with call target" << sub_listener.str(); 132 return result; 133 } 134 135 } // namespace testing 136 137 void PrintTo(const HloInstruction* inst, ::std::ostream* os) { 138 *os << (inst ? inst->ToString() : "nullptr"); 139 } 140 141 void PrintTo(HloInstruction* inst, ::std::ostream* os) { 142 PrintTo(const_cast<const HloInstruction*>(inst), os); 143 } 144 145 } // namespace xla 146