1 // test.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // 16 // \file 17 // Function to test equality of two Fsts. 18 19 #ifndef FST_LIB_EQUAL_H__ 20 #define FST_LIB_EQUAL_H__ 21 22 #include "fst/lib/fst.h" 23 24 namespace fst { 25 26 // Tests if two Fsts have the same states and arcs in the same order. 27 template<class Arc> 28 bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2) { 29 typedef typename Arc::StateId StateId; 30 typedef typename Arc::Weight Weight; 31 32 if (fst1.Start() != fst2.Start()) { 33 VLOG(1) << "Equal: mismatched start states"; 34 return false; 35 } 36 37 StateIterator< Fst<Arc> > siter1(fst1); 38 StateIterator< Fst<Arc> > siter2(fst2); 39 40 while (!siter1.Done() || !siter2.Done()) { 41 if (siter1.Done() || siter2.Done()) { 42 VLOG(1) << "Equal: mismatched # of states"; 43 return false; 44 } 45 StateId s1 = siter1.Value(); 46 StateId s2 = siter2.Value(); 47 if (s1 != s2) { 48 VLOG(1) << "Equal: mismatched states:" 49 << ", state1 = " << s1 50 << ", state2 = " << s2; 51 return false; 52 } 53 Weight final1 = fst1.Final(s1); 54 Weight final2 = fst2.Final(s2); 55 if (!ApproxEqual(final1, final2)) { 56 VLOG(1) << "Equal: mismatched final weights:" 57 << " state = " << s1 58 << ", final1 = " << final1 59 << ", final2 = " << final2; 60 return false; 61 } 62 ArcIterator< Fst<Arc> > aiter1(fst1, s1); 63 ArcIterator< Fst<Arc> > aiter2(fst2, s2); 64 for (size_t a = 0; !aiter1.Done() || !aiter2.Done(); ++a) { 65 if (aiter1.Done() || aiter2.Done()) { 66 VLOG(1) << "Equal: mismatched # of arcs" 67 << " state = " << s1; 68 return false; 69 } 70 Arc arc1 = aiter1.Value(); 71 Arc arc2 = aiter2.Value(); 72 if (arc1.ilabel != arc2.ilabel) { 73 VLOG(1) << "Equal: mismatched arc input labels:" 74 << " state = " << s1 75 << ", arc = " << a 76 << ", ilabel1 = " << arc1.ilabel 77 << ", ilabel2 = " << arc2.ilabel; 78 return false; 79 } else if (arc1.olabel != arc2.olabel) { 80 VLOG(1) << "Equal: mismatched arc output labels:" 81 << " state = " << s1 82 << ", arc = " << a 83 << ", olabel1 = " << arc1.olabel 84 << ", olabel2 = " << arc2.olabel; 85 return false; 86 } else if (!ApproxEqual(arc1.weight, arc2.weight)) { 87 VLOG(1) << "Equal: mismatched arc weights:" 88 << " state = " << s1 89 << ", arc = " << a 90 << ", weight1 = " << arc1.weight 91 << ", weight2 = " << arc2.weight; 92 return false; 93 } else if (arc1.nextstate != arc2.nextstate) { 94 VLOG(1) << "Equal: mismatched input label:" 95 << " state = " << s1 96 << ", arc = " << a 97 << ", nextstate1 = " << arc1.nextstate 98 << ", nextstate2 = " << arc2.nextstate; 99 return false; 100 } 101 aiter1.Next(); 102 aiter2.Next(); 103 104 } 105 // Sanity checks 106 CHECK_EQ(fst1.NumArcs(s1), fst2.NumArcs(s2)); 107 CHECK_EQ(fst1.NumInputEpsilons(s1), fst2.NumInputEpsilons(s2)); 108 CHECK_EQ(fst1.NumOutputEpsilons(s1), fst2.NumOutputEpsilons(s2)); 109 110 siter1.Next(); 111 siter2.Next(); 112 } 113 return true; 114 } 115 116 } // namespace fst 117 118 119 #endif // FST_LIB_EQUAL_H__ 120