Home | History | Annotate | Download | only in fst
// equal.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Function to test equality of two Fsts.
     20 
     21 #ifndef FST_LIB_EQUAL_H__
     22 #define FST_LIB_EQUAL_H__
     23 
     24 #include <fst/fst.h>
     25 
     26 
     27 namespace fst {
     28 
     29 // Tests if two Fsts have the same states and arcs in the same order.
     30 template<class Arc>
     31 bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2, float delta = kDelta) {
     32   typedef typename Arc::StateId StateId;
     33   typedef typename Arc::Weight Weight;
     34 
     35   if (fst1.Start() != fst2.Start()) {
     36     VLOG(1) << "Equal: mismatched start states";
     37     return false;
     38   }
     39 
     40   StateIterator< Fst<Arc> > siter1(fst1);
     41   StateIterator< Fst<Arc> > siter2(fst2);
     42 
     43   while (!siter1.Done() || !siter2.Done()) {
     44     if (siter1.Done() || siter2.Done()) {
     45       VLOG(1) << "Equal: mismatched # of states";
     46       return false;
     47     }
     48     StateId s1 = siter1.Value();
     49     StateId s2 = siter2.Value();
     50     if (s1 != s2) {
     51       VLOG(1) << "Equal: mismatched states:"
     52               << ", state1 = " << s1
     53               << ", state2 = " << s2;
     54       return false;
     55     }
     56     Weight final1 = fst1.Final(s1);
     57     Weight final2 = fst2.Final(s2);
     58     if (!ApproxEqual(final1, final2, delta)) {
     59       VLOG(1) << "Equal: mismatched final weights:"
     60               << " state = " << s1
     61               << ", final1 = " << final1
     62               << ", final2 = " << final2;
     63       return false;
     64      }
     65     ArcIterator< Fst<Arc> > aiter1(fst1, s1);
     66     ArcIterator< Fst<Arc> > aiter2(fst2, s2);
     67     for (size_t a = 0; !aiter1.Done() || !aiter2.Done(); ++a) {
     68       if (aiter1.Done() || aiter2.Done()) {
     69         VLOG(1) << "Equal: mismatched # of arcs"
     70                 << " state = " << s1;
     71         return false;
     72       }
     73       Arc arc1 = aiter1.Value();
     74       Arc arc2 = aiter2.Value();
     75       if (arc1.ilabel != arc2.ilabel) {
     76         VLOG(1) << "Equal: mismatched arc input labels:"
     77                 << " state = " << s1
     78                 << ", arc = " << a
     79                 << ", ilabel1 = " << arc1.ilabel
     80                 << ", ilabel2 = " << arc2.ilabel;
     81         return false;
     82       } else  if (arc1.olabel != arc2.olabel) {
     83         VLOG(1) << "Equal: mismatched arc output labels:"
     84                 << " state = " << s1
     85                 << ", arc = " << a
     86                 << ", olabel1 = " << arc1.olabel
     87                 << ", olabel2 = " << arc2.olabel;
     88         return false;
     89       } else  if (!ApproxEqual(arc1.weight, arc2.weight, delta)) {
     90         VLOG(1) << "Equal: mismatched arc weights:"
     91                 << " state = " << s1
     92                 << ", arc = " << a
     93                 << ", weight1 = " << arc1.weight
     94                 << ", weight2 = " << arc2.weight;
     95         return false;
     96       } else  if (arc1.nextstate != arc2.nextstate) {
     97         VLOG(1) << "Equal: mismatched input label:"
     98                 << " state = " << s1
     99                 << ", arc = " << a
    100                 << ", nextstate1 = " << arc1.nextstate
    101                 << ", nextstate2 = " << arc2.nextstate;
    102         return false;
    103       }
    104       aiter1.Next();
    105       aiter2.Next();
    106 
    107     }
    108     // Sanity checks: should never fail
    109     if (fst1.NumArcs(s1) != fst2.NumArcs(s2) ||
    110         fst1.NumInputEpsilons(s1) != fst2.NumInputEpsilons(s2) ||
    111         fst1.NumOutputEpsilons(s1) != fst2.NumOutputEpsilons(s2)) {
    112       FSTERROR() << "Equal: inconsistent arc/epsilon counts";
    113     }
    114 
    115     siter1.Next();
    116     siter2.Next();
    117   }
    118   return true;
    119 }
    120 
    121 }  // namespace fst
    122 
    123 
    124 #endif  // FST_LIB_EQUAL_H__
    125