Home | History | Annotate | Download | only in lib
      1 // equal.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
     17 // Function to test equality of two Fsts.
     18 
     19 #ifndef FST_LIB_EQUAL_H__
     20 #define FST_LIB_EQUAL_H__
     21 
     22 #include "fst/lib/fst.h"
     23 
     24 namespace fst {
     25 
     26 // Tests if two Fsts have the same states and arcs in the same order.
     27 template<class Arc>
     28 bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2) {
     29   typedef typename Arc::StateId StateId;
     30   typedef typename Arc::Weight Weight;
     31 
     32   if (fst1.Start() != fst2.Start()) {
     33     VLOG(1) << "Equal: mismatched start states";
     34     return false;
     35   }
     36 
     37   StateIterator< Fst<Arc> > siter1(fst1);
     38   StateIterator< Fst<Arc> > siter2(fst2);
     39 
     40   while (!siter1.Done() || !siter2.Done()) {
     41     if (siter1.Done() || siter2.Done()) {
     42       VLOG(1) << "Equal: mismatched # of states";
     43       return false;
     44     }
     45     StateId s1 = siter1.Value();
     46     StateId s2 = siter2.Value();
     47     if (s1 != s2) {
     48       VLOG(1) << "Equal: mismatched states:"
     49               << ", state1 = " << s1
     50               << ", state2 = " << s2;
     51       return false;
     52     }
     53     Weight final1 = fst1.Final(s1);
     54     Weight final2 = fst2.Final(s2);
     55     if (!ApproxEqual(final1, final2)) {
     56       VLOG(1) << "Equal: mismatched final weights:"
     57               << " state = " << s1
     58               << ", final1 = " << final1
     59               << ", final2 = " << final2;
     60       return false;
     61      }
     62     ArcIterator< Fst<Arc> > aiter1(fst1, s1);
     63     ArcIterator< Fst<Arc> > aiter2(fst2, s2);
     64     for (size_t a = 0; !aiter1.Done() || !aiter2.Done(); ++a) {
     65       if (aiter1.Done() || aiter2.Done()) {
     66         VLOG(1) << "Equal: mismatched # of arcs"
     67                 << " state = " << s1;
     68         return false;
     69       }
     70       Arc arc1 = aiter1.Value();
     71       Arc arc2 = aiter2.Value();
     72       if (arc1.ilabel != arc2.ilabel) {
     73         VLOG(1) << "Equal: mismatched arc input labels:"
     74                 << " state = " << s1
     75                 << ", arc = " << a
     76                 << ", ilabel1 = " << arc1.ilabel
     77                 << ", ilabel2 = " << arc2.ilabel;
     78         return false;
     79       } else  if (arc1.olabel != arc2.olabel) {
     80         VLOG(1) << "Equal: mismatched arc output labels:"
     81                 << " state = " << s1
     82                 << ", arc = " << a
     83                 << ", olabel1 = " << arc1.olabel
     84                 << ", olabel2 = " << arc2.olabel;
     85         return false;
     86       } else  if (!ApproxEqual(arc1.weight, arc2.weight)) {
     87         VLOG(1) << "Equal: mismatched arc weights:"
     88                 << " state = " << s1
     89                 << ", arc = " << a
     90                 << ", weight1 = " << arc1.weight
     91                 << ", weight2 = " << arc2.weight;
     92         return false;
     93       } else  if (arc1.nextstate != arc2.nextstate) {
     94         VLOG(1) << "Equal: mismatched input label:"
     95                 << " state = " << s1
     96                 << ", arc = " << a
     97                 << ", nextstate1 = " << arc1.nextstate
     98                 << ", nextstate2 = " << arc2.nextstate;
     99         return false;
    100       }
    101       aiter1.Next();
    102       aiter2.Next();
    103 
    104     }
    105     // Sanity checks
    106     CHECK_EQ(fst1.NumArcs(s1), fst2.NumArcs(s2));
    107     CHECK_EQ(fst1.NumInputEpsilons(s1), fst2.NumInputEpsilons(s2));
    108     CHECK_EQ(fst1.NumOutputEpsilons(s1), fst2.NumOutputEpsilons(s2));
    109 
    110     siter1.Next();
    111     siter2.Next();
    112   }
    113   return true;
    114 }
    115 
    116 }  // namespace fst
    117 
    118 
    119 #endif  // FST_LIB_EQUAL_H__
    120