Home | History | Annotate | Download | only in lib
      1 // compose.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
     17 // Class to compute the composition of two FSTs
     18 
     19 #ifndef FST_LIB_COMPOSE_H__
     20 #define FST_LIB_COMPOSE_H__
     21 
     22 #include <algorithm>
     23 
     24 #include <ext/hash_map>
     25 using __gnu_cxx::hash_map;
     26 
     27 #include "fst/lib/cache.h"
     28 #include "fst/lib/test-properties.h"
     29 
     30 namespace fst {
     31 
     32 // Enumeration of uint64 bits used to represent the user-defined
     33 // properties of FST composition (in the template parameter to
     34 // ComposeFstOptions<T>). The bits stand for extensions of generic FST
     35 // composition. ComposeFstOptions<> (all the bits unset) is the "plain"
     36 // compose without any extra extensions.
     37 enum ComposeTypes {
     38   // RHO: flags dealing with a special "rest" symbol in the FSTs.
     39   // NB: at most one of the bits COMPOSE_FST1_RHO, COMPOSE_FST2_RHO
     40   // may be set.
     41   COMPOSE_FST1_RHO    = 1ULL<<0,  // "Rest" symbol on the output side of fst1.
     42   COMPOSE_FST2_RHO    = 1ULL<<1,  // "Rest" symbol on the input side of fst2.
     43   COMPOSE_FST1_PHI    = 1ULL<<2,  // "Failure" symbol on the output
     44                                   // side of fst1.
     45   COMPOSE_FST2_PHI    = 1ULL<<3,  // "Failure" symbol on the input side
     46                                   // of fst2.
     47   COMPOSE_FST1_SIGMA  = 1ULL<<4,  // "Any" symbol on the output side of
     48                                   // fst1.
     49   COMPOSE_FST2_SIGMA  = 1ULL<<5,  // "Any" symbol on the input side of
     50                                   // fst2.
     51   // Optimization related bits.
     52   COMPOSE_GENERIC     = 1ULL<<32,  // Disables optimizations, applies
     53                                    // the generic version of the
     54                                    // composition algorithm. This flag
     55                                    // is used for internal testing
     56                                    // only.
     57 
     58   // -----------------------------------------------------------------
     59   // Auxiliary enum values denoting specific combinations of
     60   // bits. Internal use only.
     61   COMPOSE_RHO         = COMPOSE_FST1_RHO | COMPOSE_FST2_RHO,
     62   COMPOSE_PHI         = COMPOSE_FST1_PHI | COMPOSE_FST2_PHI,
     63   COMPOSE_SIGMA       = COMPOSE_FST1_SIGMA | COMPOSE_FST2_SIGMA,
     64   COMPOSE_SPECIAL_SYMBOLS = COMPOSE_RHO | COMPOSE_PHI | COMPOSE_SIGMA,
     65 
     66   // -----------------------------------------------------------------
     67   // The following bits, denoting specific optimizations, are
     68   // typically set *internally* by the composition algorithm.
     69   COMPOSE_FST1_STRING = 1ULL<<33,  // fst1 is a string
     70   COMPOSE_FST2_STRING = 1ULL<<34,  // fst2 is a string
     71   COMPOSE_FST1_DET    = 1ULL<<35,  // fst1 is deterministic
     72   COMPOSE_FST2_DET    = 1ULL<<36,  // fst2 is deterministic
     73   COMPOSE_INTERNAL_MASK    = 0xffffffff00000000ULL
     74 };
     75 
     76 
     77 template <uint64 T = 0ULL>
     78 struct ComposeFstOptions : public CacheOptions {
     79   explicit ComposeFstOptions(const CacheOptions &opts) : CacheOptions(opts) {}
     80   ComposeFstOptions() { }
     81 };
     82 
     83 
     84 // Abstract base for the implementation of delayed ComposeFst. The
     85 // concrete specializations are templated on the (uint64-valued)
     86 // properties of the FSTs being composed.
     87 template <class A>
     88 class ComposeFstImplBase : public CacheImpl<A> {
     89  public:
     90   using FstImpl<A>::SetType;
     91   using FstImpl<A>::SetProperties;
     92   using FstImpl<A>::Properties;
     93   using FstImpl<A>::SetInputSymbols;
     94   using FstImpl<A>::SetOutputSymbols;
     95 
     96   using CacheBaseImpl< CacheState<A> >::HasStart;
     97   using CacheBaseImpl< CacheState<A> >::HasFinal;
     98   using CacheBaseImpl< CacheState<A> >::HasArcs;
     99 
    100   typedef typename A::Label Label;
    101   typedef typename A::Weight Weight;
    102   typedef typename A::StateId StateId;
    103   typedef CacheState<A> State;
    104 
    105   ComposeFstImplBase(const Fst<A> &fst1,
    106                      const Fst<A> &fst2,
    107                      const CacheOptions &opts)
    108       :CacheImpl<A>(opts), fst1_(fst1.Copy()), fst2_(fst2.Copy()) {
    109     SetType("compose");
    110     uint64 props1 = fst1.Properties(kFstProperties, false);
    111     uint64 props2 = fst2.Properties(kFstProperties, false);
    112     SetProperties(ComposeProperties(props1, props2), kCopyProperties);
    113 
    114     if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols()))
    115       LOG(FATAL) << "ComposeFst: output symbol table of 1st argument "
    116                  << "does not match input symbol table of 2nd argument";
    117 
    118     SetInputSymbols(fst1.InputSymbols());
    119     SetOutputSymbols(fst2.OutputSymbols());
    120   }
    121 
    122   virtual ~ComposeFstImplBase() {
    123     delete fst1_;
    124     delete fst2_;
    125   }
    126 
    127   StateId Start() {
    128     if (!HasStart()) {
    129       StateId start = ComputeStart();
    130       if (start != kNoStateId) {
    131         this->SetStart(start);
    132       }
    133     }
    134     return CacheImpl<A>::Start();
    135   }
    136 
    137   Weight Final(StateId s) {
    138     if (!HasFinal(s)) {
    139       Weight final = ComputeFinal(s);
    140       this->SetFinal(s, final);
    141     }
    142     return CacheImpl<A>::Final(s);
    143   }
    144 
    145   virtual void Expand(StateId s) = 0;
    146 
    147   size_t NumArcs(StateId s) {
    148     if (!HasArcs(s))
    149       Expand(s);
    150     return CacheImpl<A>::NumArcs(s);
    151   }
    152 
    153   size_t NumInputEpsilons(StateId s) {
    154     if (!HasArcs(s))
    155       Expand(s);
    156     return CacheImpl<A>::NumInputEpsilons(s);
    157   }
    158 
    159   size_t NumOutputEpsilons(StateId s) {
    160     if (!HasArcs(s))
    161       Expand(s);
    162     return CacheImpl<A>::NumOutputEpsilons(s);
    163   }
    164 
    165   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    166     if (!HasArcs(s))
    167       Expand(s);
    168     CacheImpl<A>::InitArcIterator(s, data);
    169   }
    170 
    171   // Access to flags encoding compose options/optimizations etc.  (for
    172   // debugging).
    173   virtual uint64 ComposeFlags() const = 0;
    174 
    175  protected:
    176   virtual StateId ComputeStart() = 0;
    177   virtual Weight ComputeFinal(StateId s) = 0;
    178 
    179   const Fst<A> *fst1_;            // first input Fst
    180   const Fst<A> *fst2_;            // second input Fst
    181 };
    182 
    183 
    184 // The following class encapsulates implementation-dependent details
    185 // of state tuple lookup, i.e. a bijective mapping from triples of two
    186 // FST states and an epsilon filter state to the corresponding state
    187 // IDs of the fst resulting from composition. The mapping must
    188 // implement the [] operator in the style of STL associative
    189 // containers (map, hash_map), i.e. table[x] must return a reference
    190 // to the value associated with x. If x is an unassigned tuple, the
    191 // operator must automatically associate x with value 0.
    192 //
    193 // NB: "table[x] == 0" for unassigned tuples x is required by the
    194 // following off-by-one device used in the implementation of
    195 // ComposeFstImpl. The value stored in the table is equal to tuple ID
    196 // plus one, i.e. it is always a strictly positive number. Therefore,
    197 // table[x] is equal to 0 if and only if x is an unassigned tuple (in
    198 // which the algorithm assigns a new ID to x, and sets table[x] -
    199 // stored in a reference - to "new ID + 1"). This form of lookup is
    200 // more efficient than calling "find(x)" and "insert(make_pair(x, new
    201 // ID))" if x is an unassigned tuple.
    202 //
    203 // The generic implementation is a wrapper around a hash_map.
    204 template <class A, uint64 T>
    205 class ComposeStateTable {
    206  public:
    207   typedef typename A::StateId StateId;
    208 
    209   struct StateTuple {
    210     StateTuple() {}
    211     StateTuple(StateId s1, StateId s2, int f)
    212         : state_id1(s1), state_id2(s2), filt(f) {}
    213     StateId state_id1;  // state Id on fst1
    214     StateId state_id2;  // state Id on fst2
    215     int filt;           // epsilon filter state
    216   };
    217 
    218   ComposeStateTable() {
    219     StateTuple empty_tuple(kNoStateId, kNoStateId, 0);
    220   }
    221 
    222   // NB: if 'tuple' is not in 'table_', the pair (tuple, StateId()) is
    223   // inserted into 'table_' (standard STL container semantics). Since
    224   // StateId is a built-in type, the explicit default constructor call
    225   // StateId() returns 0.
    226   StateId &operator[](const StateTuple &tuple) {
    227     return table_[tuple];
    228   }
    229 
    230  private:
    231   // Comparison object for hashing StateTuple(s).
    232   class StateTupleEqual {
    233    public:
    234     bool operator()(const StateTuple& x, const StateTuple& y) const {
    235       return x.state_id1 == y.state_id1 &&
    236              x.state_id2 == y.state_id2 &&
    237              x.filt == y.filt;
    238     }
    239   };
    240 
    241   static const int kPrime0 = 7853;
    242   static const int kPrime1 = 7867;
    243 
    244   // Hash function for StateTuple to Fst states.
    245   class StateTupleKey {
    246    public:
    247     size_t operator()(const StateTuple& x) const {
    248       return static_cast<size_t>(x.state_id1 +
    249                                  x.state_id2 * kPrime0 +
    250                                  x.filt * kPrime1);
    251     }
    252   };
    253 
    254   // Lookup table mapping state tuples to state IDs.
    255   typedef hash_map<StateTuple,
    256                          StateId,
    257                          StateTupleKey,
    258                          StateTupleEqual> StateTable;
    259  // Actual table data.
    260   StateTable table_;
    261 
    262   DISALLOW_EVIL_CONSTRUCTORS(ComposeStateTable);
    263 };
    264 
    265 
    266 // State tuple lookup table for the composition of a string FST with a
    267 // deterministic FST.  The class maps state tuples to their unique IDs
    268 // (i.e. states of the ComposeFst). Main optimization: due to the
    269 // 1-to-1 correspondence between the states of the input string FST
    270 // and those of the resulting (string) FST, a state tuple (s1, s2) is
    271 // simply mapped to StateId s1. Hence, we use an STL vector as a
    272 // lookup table. Template argument Fst1IsString specifies which FST is
    273 // a string (this determines whether or not we index the lookup table
    274 // by the first or by the second state).
    275 template <class A, bool Fst1IsString>
    276 class StringDetComposeStateTable {
    277  public:
    278   typedef typename A::StateId StateId;
    279 
    280   struct StateTuple {
    281     typedef typename A::StateId StateId;
    282     StateTuple() {}
    283     StateTuple(StateId s1, StateId s2, int /* f */)
    284         : state_id1(s1), state_id2(s2) {}
    285     StateId state_id1;  // state Id on fst1
    286     StateId state_id2;  // state Id on fst2
    287     static const int filt = 0;  // 'fake' epsilon filter - only needed
    288                                 // for API compatibility
    289   };
    290 
    291   StringDetComposeStateTable() {}
    292 
    293   // Subscript operator. Behaves in a way similar to its map/hash_map
    294   // counterpart, i.e. returns a reference to the value associated
    295   // with 'tuple', inserting a 0 value if 'tuple' is unassigned.
    296   StateId &operator[](const StateTuple &tuple) {
    297     StateId index = Fst1IsString ? tuple.state_id1 : tuple.state_id2;
    298     if (index >= (StateId)data_.size()) {
    299       // NB: all values in [old_size; index] are initialized to 0.
    300       data_.resize(index + 1);
    301     }
    302     return data_[index];
    303   }
    304 
    305  private:
    306   vector<StateId> data_;
    307 
    308   DISALLOW_EVIL_CONSTRUCTORS(StringDetComposeStateTable);
    309 };
    310 
    311 
    312 // Specializations of ComposeStateTable for the string/det case.
    313 // Both inherit from StringDetComposeStateTable.
    314 template <class A>
    315 class ComposeStateTable<A, COMPOSE_FST1_STRING | COMPOSE_FST2_DET>
    316     : public StringDetComposeStateTable<A, true> { };
    317 
    318 template <class A>
    319 class ComposeStateTable<A, COMPOSE_FST2_STRING | COMPOSE_FST1_DET>
    320     : public StringDetComposeStateTable<A, false> { };
    321 
    322 
    323 // Parameterized implementation of FST composition for a pair of FSTs
    324 // matching the property bit vector T. If possible,
    325 // instantiation-specific switches in the code are based on the values
    326 // of the bits in T, which are known at compile time, so unused code
    327 // should be optimized away by the compiler.
    328 template <class A, uint64 T>
    329 class ComposeFstImpl : public ComposeFstImplBase<A> {
    330   typedef typename A::StateId StateId;
    331   typedef typename A::Label   Label;
    332   typedef typename A::Weight  Weight;
    333   using FstImpl<A>::SetType;
    334   using FstImpl<A>::SetProperties;
    335 
    336   enum FindType { FIND_INPUT  = 1,          // find input label on fst2
    337                   FIND_OUTPUT = 2,          // find output label on fst1
    338                   FIND_BOTH   = 3 };        // find choice state dependent
    339 
    340   typedef ComposeStateTable<A, T & COMPOSE_INTERNAL_MASK> StateTupleTable;
    341   typedef typename StateTupleTable::StateTuple StateTuple;
    342 
    343  public:
    344   ComposeFstImpl(const Fst<A> &fst1,
    345                  const Fst<A> &fst2,
    346                  const CacheOptions &opts)
    347       :ComposeFstImplBase<A>(fst1, fst2, opts) {
    348 
    349     bool osorted = fst1.Properties(kOLabelSorted, false);
    350     bool isorted = fst2.Properties(kILabelSorted, false);
    351 
    352     switch (T & COMPOSE_SPECIAL_SYMBOLS) {
    353       case COMPOSE_FST1_RHO:
    354       case COMPOSE_FST1_PHI:
    355       case COMPOSE_FST1_SIGMA:
    356         if (!osorted || FLAGS_fst_verify_properties)
    357           osorted = fst1.Properties(kOLabelSorted, true);
    358         if (!osorted)
    359           LOG(FATAL) << "ComposeFst: 1st argument not output label "
    360                      << "sorted (special symbols present)";
    361         break;
    362       case COMPOSE_FST2_RHO:
    363       case COMPOSE_FST2_PHI:
    364       case COMPOSE_FST2_SIGMA:
    365         if (!isorted || FLAGS_fst_verify_properties)
    366           isorted = fst2.Properties(kILabelSorted, true);
    367         if (!isorted)
    368           LOG(FATAL) << "ComposeFst: 2nd argument not input label "
    369                      << "sorted (special symbols present)";
    370         break;
    371       case 0:
    372         if (!isorted && !osorted || FLAGS_fst_verify_properties) {
    373           osorted = fst1.Properties(kOLabelSorted, true);
    374           if (!osorted)
    375             isorted = fst2.Properties(kILabelSorted, true);
    376         }
    377         break;
    378       default:
    379         LOG(FATAL)
    380           << "ComposeFst: More than one special symbol used in composition";
    381     }
    382 
    383     if (isorted && (T & COMPOSE_FST2_SIGMA)) {
    384       find_type_ = FIND_INPUT;
    385     } else if (osorted && (T & COMPOSE_FST1_SIGMA)) {
    386       find_type_ = FIND_OUTPUT;
    387     } else if (isorted && (T & COMPOSE_FST2_PHI)) {
    388       find_type_ = FIND_INPUT;
    389     } else if (osorted && (T & COMPOSE_FST1_PHI)) {
    390       find_type_ = FIND_OUTPUT;
    391     } else if (isorted && (T & COMPOSE_FST2_RHO)) {
    392       find_type_ = FIND_INPUT;
    393     } else if (osorted && (T & COMPOSE_FST1_RHO)) {
    394       find_type_ = FIND_OUTPUT;
    395     } else if (isorted && (T & COMPOSE_FST1_STRING)) {
    396       find_type_ = FIND_INPUT;
    397     } else if(osorted && (T & COMPOSE_FST2_STRING)) {
    398       find_type_ = FIND_OUTPUT;
    399     } else if (isorted && osorted) {
    400       find_type_ = FIND_BOTH;
    401     } else if (isorted) {
    402       find_type_ = FIND_INPUT;
    403     } else if (osorted) {
    404       find_type_ = FIND_OUTPUT;
    405     } else {
    406       LOG(FATAL) << "ComposeFst: 1st argument not output label sorted "
    407                  << "and 2nd argument is not input label sorted";
    408     }
    409   }
    410 
    411   // Finds/creates an Fst state given a StateTuple.  Only creates a new
    412   // state if StateTuple is not found in the state hash.
    413   //
    414   // The method exploits the following device: all pairs stored in the
    415   // associative container state_tuple_table_ are of the form (tuple,
    416   // id(tuple) + 1), i.e. state_tuple_table_[tuple] > 0 if tuple has
    417   // been stored previously. For unassigned tuples, the call to
    418   // state_tuple_table_[tuple] creates a new pair (tuple, 0). As a
    419   // result, state_tuple_table_[tuple] == 0 iff tuple is new.
    420   StateId FindState(const StateTuple& tuple) {
    421     StateId &assoc_value = state_tuple_table_[tuple];
    422     if (assoc_value == 0) {  // tuple wasn't present in lookup table:
    423                              // assign it a new ID.
    424       state_tuples_.push_back(tuple);
    425       assoc_value = state_tuples_.size();
    426     }
    427     return assoc_value - 1;  // NB: assoc_value = ID + 1
    428   }
    429 
    430   // Generates arc for composition state s from matched input Fst arcs.
    431   void AddArc(StateId s, const A &arca, const A &arcb, int f,
    432               bool find_input) {
    433     A arc;
    434     if (find_input) {
    435       arc.ilabel = arcb.ilabel;
    436       arc.olabel = arca.olabel;
    437       arc.weight = Times(arcb.weight, arca.weight);
    438       StateTuple tuple(arcb.nextstate, arca.nextstate, f);
    439       arc.nextstate = FindState(tuple);
    440     } else {
    441       arc.ilabel = arca.ilabel;
    442       arc.olabel = arcb.olabel;
    443       arc.weight = Times(arca.weight, arcb.weight);
    444       StateTuple tuple(arca.nextstate, arcb.nextstate, f);
    445       arc.nextstate = FindState(tuple);
    446     }
    447     CacheImpl<A>::AddArc(s, arc);
    448   }
    449 
    450   // Arranges it so that the first arg to OrderedExpand is the Fst
    451   // that will be passed to FindLabel.
    452   void Expand(StateId s) {
    453     StateTuple &tuple = state_tuples_[s];
    454     StateId s1 = tuple.state_id1;
    455     StateId s2 = tuple.state_id2;
    456     int f = tuple.filt;
    457     if (find_type_ == FIND_INPUT)
    458       OrderedExpand(s, ComposeFstImplBase<A>::fst2_, s2,
    459                     ComposeFstImplBase<A>::fst1_, s1, f, true);
    460     else
    461       OrderedExpand(s, ComposeFstImplBase<A>::fst1_, s1,
    462                     ComposeFstImplBase<A>::fst2_, s2, f, false);
    463   }
    464 
    465   // Access to flags encoding compose options/optimizations etc.  (for
    466   // debugging).
    467   virtual uint64 ComposeFlags() const { return T; }
    468 
    469  private:
    470   // This does that actual matching of labels in the composition. The
    471   // arguments are ordered so FindLabel is called with state SA of
    472   // FSTA for each arc leaving state SB of FSTB. The FIND_INPUT arg
    473   // determines whether the input or output label of arcs at SB is
    474   // the one to match on.
    475   void OrderedExpand(StateId s, const Fst<A> *fsta, StateId sa,
    476                      const Fst<A> *fstb, StateId sb, int f, bool find_input) {
    477 
    478     size_t numarcsa = fsta->NumArcs(sa);
    479     size_t numepsa = find_input ? fsta->NumInputEpsilons(sa) :
    480                      fsta->NumOutputEpsilons(sa);
    481     bool finala = fsta->Final(sa) != Weight::Zero();
    482     ArcIterator< Fst<A> > aitera(*fsta, sa);
    483     // First handle special epsilons and sigmas on FSTA
    484     for (; !aitera.Done(); aitera.Next()) {
    485       const A &arca = aitera.Value();
    486       Label match_labela = find_input ? arca.ilabel : arca.olabel;
    487       if (match_labela > 0) {
    488         break;
    489       }
    490       if ((T & COMPOSE_SIGMA) != 0 &&  match_labela == kSigmaLabel) {
    491         // Found a sigma? Match it against all (non-special) symbols
    492         // on side b.
    493         for (ArcIterator< Fst<A> > aiterb(*fstb, sb);
    494              !aiterb.Done();
    495              aiterb.Next()) {
    496           const A &arcb = aiterb.Value();
    497           Label labelb = find_input ? arcb.olabel : arcb.ilabel;
    498           if (labelb <= 0) continue;
    499           AddArc(s, arca, arcb, 0, find_input);
    500         }
    501       } else if (f == 0 && match_labela == 0) {
    502         A earcb(0, 0, Weight::One(), sb);
    503         AddArc(s, arca, earcb, 0, find_input);  // move forward on epsilon
    504       }
    505     }
    506     // Next handle non-epsilon matches, rho labels, and epsilons on FSTB
    507     for (ArcIterator< Fst<A> > aiterb(*fstb, sb);
    508          !aiterb.Done();
    509          aiterb.Next()) {
    510       const A &arcb = aiterb.Value();
    511       Label match_labelb = find_input ? arcb.olabel : arcb.ilabel;
    512       if (match_labelb) {  // Consider non-epsilon match
    513         if (FindLabel(&aitera, numarcsa, match_labelb, find_input)) {
    514           for (; !aitera.Done(); aitera.Next()) {
    515             const A &arca = aitera.Value();
    516             Label match_labela = find_input ? arca.ilabel : arca.olabel;
    517             if (match_labela != match_labelb)
    518               break;
    519             AddArc(s, arca, arcb, 0, find_input);  // move forward on match
    520           }
    521         } else if ((T & COMPOSE_SPECIAL_SYMBOLS) != 0) {
    522           // If there is no transition labelled 'match_labelb' in
    523           // fsta, try matching 'match_labelb' against special symbols
    524           // (Phi, Rho,...).
    525           for (aitera.Reset(); !aitera.Done(); aitera.Next()) {
    526             A arca = aitera.Value();
    527             Label labela = find_input ? arca.ilabel : arca.olabel;
    528             if (labela >= 0) {
    529               break;
    530             } else if (((T & COMPOSE_PHI) != 0) && (labela == kPhiLabel)) {
    531               // Case 1: if a failure transition exists, follow its
    532               // transitive closure until a) a transition labelled
    533               // 'match_labelb' is found, or b) the initial state of
    534               // fsta is reached.
    535 
    536               StateId sf = sa;  // Start of current failure transition.
    537               while (labela == kPhiLabel && sf != arca.nextstate) {
    538                 sf = arca.nextstate;
    539 
    540                 size_t numarcsf = fsta->NumArcs(sf);
    541                 ArcIterator< Fst<A> > aiterf(*fsta, sf);
    542                 if (FindLabel(&aiterf, numarcsf, match_labelb, find_input)) {
    543                   // Sub-case 1a: there exists a transition starting
    544                   // in sf and consuming symbol 'match_labelb'.
    545                   AddArc(s, aiterf.Value(), arcb, 0, find_input);
    546                   break;
    547                 } else {
    548                   // No transition labelled 'match_labelb' found: try
    549                   // next failure transition (starting at 'sf').
    550                   for (aiterf.Reset(); !aiterf.Done(); aiterf.Next()) {
    551                     arca = aiterf.Value();
    552                     labela = find_input ? arca.ilabel : arca.olabel;
    553                     if (labela >= kPhiLabel) break;
    554                   }
    555                 }
    556               }
    557               if (labela == kPhiLabel && sf == arca.nextstate) {
    558                 // Sub-case 1b: failure transitions lead to start
    559                 // state without finding a matching
    560                 // transition. Therefore, we generate a loop in start
    561                 // state of fsta.
    562                 A loop(match_labelb, match_labelb, Weight::One(), sf);
    563                 AddArc(s, loop, arcb, 0, find_input);
    564               }
    565             } else if (((T & COMPOSE_RHO) != 0) && (labela == kRhoLabel)) {
    566               // Case 2: 'match_labelb' can be matched against a
    567               // "rest" (rho) label in fsta.
    568               if (find_input) {
    569                 arca.ilabel = match_labelb;
    570                 if (arca.olabel == kRhoLabel)
    571                   arca.olabel = match_labelb;
    572               } else {
    573                 arca.olabel = match_labelb;
    574                 if (arca.ilabel == kRhoLabel)
    575                   arca.ilabel = match_labelb;
    576               }
    577               AddArc(s, arca, arcb, 0, find_input);  // move fwd on match
    578             }
    579           }
    580         }
    581       } else if (numepsa != numarcsa || finala) {  // Handle FSTB epsilon
    582         A earca(0, 0, Weight::One(), sa);
    583         AddArc(s, earca, arcb, numepsa > 0, find_input);  // move on epsilon
    584       }
    585     }
    586     this->SetArcs(s);
    587    }
    588 
    589 
    590   // Finds matches to MATCH_LABEL in arcs given by AITER
    591   // using FIND_INPUT to determine whether to look on input or output.
    592   bool FindLabel(ArcIterator< Fst<A> > *aiter, size_t numarcs,
    593                  Label match_label, bool find_input) {
    594     // binary search for match
    595     size_t low = 0;
    596     size_t high = numarcs;
    597     while (low < high) {
    598       size_t mid = (low + high) / 2;
    599       aiter->Seek(mid);
    600       Label label = find_input ?
    601                     aiter->Value().ilabel : aiter->Value().olabel;
    602       if (label > match_label) {
    603         high = mid;
    604       } else if (label < match_label) {
    605         low = mid + 1;
    606       } else {
    607         // find first matching label (when non-determinism)
    608         for (size_t i = mid; i > low; --i) {
    609           aiter->Seek(i - 1);
    610           label = find_input ? aiter->Value().ilabel : aiter->Value().olabel;
    611           if (label != match_label) {
    612             aiter->Seek(i);
    613             return true;
    614           }
    615         }
    616         return true;
    617       }
    618     }
    619     return false;
    620   }
    621 
    622   StateId ComputeStart() {
    623     StateId s1 = ComposeFstImplBase<A>::fst1_->Start();
    624     StateId s2 = ComposeFstImplBase<A>::fst2_->Start();
    625     if (s1 == kNoStateId || s2 == kNoStateId)
    626       return kNoStateId;
    627     StateTuple tuple(s1, s2, 0);
    628     return FindState(tuple);
    629   }
    630 
    631   Weight ComputeFinal(StateId s) {
    632     StateTuple &tuple = state_tuples_[s];
    633     Weight final = Times(ComposeFstImplBase<A>::fst1_->Final(tuple.state_id1),
    634                          ComposeFstImplBase<A>::fst2_->Final(tuple.state_id2));
    635     return final;
    636   }
    637 
    638 
    639   FindType find_type_;            // find label on which side?
    640 
    641   // Maps from StateId to StateTuple.
    642   vector<StateTuple> state_tuples_;
    643 
    644   // Maps from StateTuple to StateId.
    645   StateTupleTable state_tuple_table_;
    646 
    647   DISALLOW_EVIL_CONSTRUCTORS(ComposeFstImpl);
    648 };
    649 
    650 
    651 // Computes the composition of two transducers. This version is a
    652 // delayed Fst. If FST1 transduces string x to y with weight a and FST2
    653 // transduces y to z with weight b, then their composition transduces
    654 // string x to z with weight Times(a, b).
    655 //
    656 // The output labels of the first transducer or the input labels of
    657 // the second transducer must be sorted.  The weights need to form a
    658 // commutative semiring (valid for TropicalWeight and LogWeight).
    659 //
    660 // Complexity:
    661 // Assuming the first FST is unsorted and the second is sorted:
    662 // - Time: O(v1 v2 d1 (log d2 + m2)),
    663 // - Space: O(v1 v2)
    664 // where vi = # of states visited, di = maximum out-degree, and mi the
    665 // maximum multiplicity of the states visited for the ith
    666 // FST. Constant time and space to visit an input state or arc is
    667 // assumed and exclusive of caching.
    668 //
    669 // Caveats:
    670 // - ComposeFst does not trim its output (since it is a delayed operation).
    671 // - The efficiency of composition can be strongly affected by several factors:
    672 //   - the choice of which transducer is sorted - prefer sorting the FST
    673 //     that has the greater average out-degree.
    674 //   - the amount of non-determinism
    675 //   - the presence and location of epsilon transitions - avoid epsilon
    676 //     transitions on the output side of the first transducer or
    677 //     the input side of the second transducer or prefer placing
    678 //     them later in a path since they delay matching and can
    679 //     introduce non-coaccessible states and transitions.
    680 template <class A>
    681 class ComposeFst : public Fst<A> {
    682  public:
    683   friend class ArcIterator< ComposeFst<A> >;
    684   friend class CacheStateIterator< ComposeFst<A> >;
    685   friend class CacheArcIterator< ComposeFst<A> >;
    686 
    687   typedef A Arc;
    688   typedef typename A::Weight Weight;
    689   typedef typename A::StateId StateId;
    690   typedef CacheState<A> State;
    691 
    692   ComposeFst(const Fst<A> &fst1, const Fst<A> &fst2)
    693       : impl_(Init(fst1, fst2, ComposeFstOptions<>())) { }
    694 
    695   template <uint64 T>
    696   ComposeFst(const Fst<A> &fst1,
    697              const Fst<A> &fst2,
    698              const ComposeFstOptions<T> &opts)
    699       : impl_(Init(fst1, fst2, opts)) { }
    700 
    701   ComposeFst(const ComposeFst<A> &fst) : Fst<A>(fst), impl_(fst.impl_) {
    702     impl_->IncrRefCount();
    703   }
    704 
    705   virtual ~ComposeFst() { if (!impl_->DecrRefCount()) delete impl_;  }
    706 
    707   virtual StateId Start() const { return impl_->Start(); }
    708 
    709   virtual Weight Final(StateId s) const { return impl_->Final(s); }
    710 
    711   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
    712 
    713   virtual size_t NumInputEpsilons(StateId s) const {
    714     return impl_->NumInputEpsilons(s);
    715   }
    716 
    717   virtual size_t NumOutputEpsilons(StateId s) const {
    718     return impl_->NumOutputEpsilons(s);
    719   }
    720 
    721   virtual uint64 Properties(uint64 mask, bool test) const {
    722     if (test) {
    723       uint64 known, test = TestProperties(*this, mask, &known);
    724       impl_->SetProperties(test, known);
    725       return test & mask;
    726     } else {
    727       return impl_->Properties(mask);
    728     }
    729   }
    730 
    731   virtual const string& Type() const { return impl_->Type(); }
    732 
    733   virtual ComposeFst<A> *Copy() const {
    734     return new ComposeFst<A>(*this);
    735   }
    736 
    737   virtual const SymbolTable* InputSymbols() const {
    738     return impl_->InputSymbols();
    739   }
    740 
    741   virtual const SymbolTable* OutputSymbols() const {
    742     return impl_->OutputSymbols();
    743   }
    744 
    745   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
    746 
    747   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    748     impl_->InitArcIterator(s, data);
    749   }
    750 
    751   // Access to flags encoding compose options/optimizations etc.  (for
    752   // debugging).
    753   uint64 ComposeFlags() const { return impl_->ComposeFlags(); }
    754 
    755  protected:
    756   ComposeFstImplBase<A> *Impl() { return impl_; }
    757 
    758  private:
    759   ComposeFstImplBase<A> *impl_;
    760 
    761   // Auxiliary method encapsulating the creation of a ComposeFst
    762   // implementation that is appropriate for the properties of fst1 and
    763   // fst2.
    764   template <uint64 T>
    765   static ComposeFstImplBase<A> *Init(
    766       const Fst<A> &fst1,
    767       const Fst<A> &fst2,
    768       const ComposeFstOptions<T> &opts) {
    769 
    770     // Filter for sort properties (forces a property check).
    771     uint64 sort_props_mask = kILabelSorted | kOLabelSorted;
    772     // Filter for optimization-related properties (does not force a
    773     // property-check).
    774     uint64 opt_props_mask =
    775       kString | kIDeterministic | kODeterministic | kNoIEpsilons |
    776       kNoOEpsilons;
    777 
    778     uint64 props1 = fst1.Properties(sort_props_mask, true);
    779     uint64 props2 = fst2.Properties(sort_props_mask, true);
    780 
    781     props1 |= fst1.Properties(opt_props_mask, false);
    782     props2 |= fst2.Properties(opt_props_mask, false);
    783 
    784     if (!(Weight::Properties() & kCommutative)) {
    785       props1 |= fst1.Properties(kUnweighted, true);
    786       props2 |= fst2.Properties(kUnweighted, true);
    787       if (!(props1 & kUnweighted) && !(props2 & kUnweighted))
    788         LOG(FATAL) << "ComposeFst: Weight needs to be a commutative semiring: "
    789                    << Weight::Type();
    790     }
    791 
    792     // Case 1: flag COMPOSE_GENERIC disables optimizations.
    793     if (T & COMPOSE_GENERIC) {
    794       return new ComposeFstImpl<A, T>(fst1, fst2, opts);
    795     }
    796 
    797     const uint64 kStringDetOptProps =
    798       kIDeterministic | kILabelSorted | kNoIEpsilons;
    799     const uint64 kDetStringOptProps =
    800       kODeterministic | kOLabelSorted | kNoOEpsilons;
    801 
    802     // Case 2: fst1 is a string, fst2 is deterministic and epsilon-free.
    803     if ((props1 & kString) &&
    804         !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) &&
    805         ((props2 & kStringDetOptProps) == kStringDetOptProps)) {
    806       return new ComposeFstImpl<A, T | COMPOSE_FST1_STRING | COMPOSE_FST2_DET>(
    807           fst1, fst2, opts);
    808     }
    809     // Case 3: fst1 is deterministic and epsilon-free, fst2 is string.
    810     if ((props2 & kString) &&
    811         !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) &&
    812         ((props1 & kDetStringOptProps) == kDetStringOptProps)) {
    813       return new ComposeFstImpl<A, T | COMPOSE_FST2_STRING | COMPOSE_FST1_DET>(
    814           fst1, fst2, opts);
    815     }
    816 
    817     // Default case: no optimizations.
    818     return new ComposeFstImpl<A, T>(fst1, fst2, opts);
    819   }
    820 
    821   void operator=(const ComposeFst<A> &fst);  // disallow
    822 };
    823 
    824 
    825 // Specialization for ComposeFst.
    826 template<class A>
    827 class StateIterator< ComposeFst<A> >
    828     : public CacheStateIterator< ComposeFst<A> > {
    829  public:
    830   explicit StateIterator(const ComposeFst<A> &fst)
    831       : CacheStateIterator< ComposeFst<A> >(fst) {}
    832 };
    833 
    834 
    835 // Specialization for ComposeFst.
    836 template <class A>
    837 class ArcIterator< ComposeFst<A> >
    838     : public CacheArcIterator< ComposeFst<A> > {
    839  public:
    840   typedef typename A::StateId StateId;
    841 
    842   ArcIterator(const ComposeFst<A> &fst, StateId s)
    843       : CacheArcIterator< ComposeFst<A> >(fst, s) {
    844     if (!fst.impl_->HasArcs(s))
    845       fst.impl_->Expand(s);
    846   }
    847 
    848  private:
    849   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
    850 };
    851 
    852 template <class A> inline
    853 void ComposeFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
    854   data->base = new StateIterator< ComposeFst<A> >(*this);
    855 }
    856 
    857 // Convenience alias: ComposeFst instantiated with StdArc.
    858 typedef ComposeFst<StdArc> StdComposeFst;
    859 
    860 
    // Options for the Compose() function.
    struct ComposeOptions {
      // When true, Compose() runs Connect() on its output.
      bool connect;

      // By default the output is connected.
      ComposeOptions()
          : connect(true) {}

      // Intentionally non-explicit so that a plain bool can be passed where
      // ComposeOptions is expected.
      ComposeOptions(bool c)
          : connect(c) {}
    };
    867 
    868 
    869 // Computes the composition of two transducers. This version writes
    870 // the composed FST into a MutableFst. If FST1 transduces string x to
    871 // y with weight a and FST2 transduces y to z with weight b, then
    872 // their composition transduces string x to z with weight
    873 // Times(a, b).
    874 //
    875 // The output labels of the first transducer or the input labels of
    876 // the second transducer must be sorted.  The weights need to form a
    877 // commutative semiring (valid for TropicalWeight and LogWeight).
    878 //
    879 // Complexity:
    880 // Assuming the first FST is unsorted and the second is sorted:
    881 // - Time: O(V1 V2 D1 (log D2 + M2)),
    882 // - Space: O(V1 V2 D1 M2)
    883 // where Vi = # of states, Di = maximum out-degree, and Mi is
    884 // the maximum multiplicity for the ith FST.
    885 //
    886 // Caveats:
    887 // - Compose trims its output.
    888 // - The efficiency of composition can be strongly affected by several factors:
    889 //   - the choice of which transducer is sorted - prefer sorting the FST
    890 //     that has the greater average out-degree.
    891 //   - the amount of non-determinism
    892 //   - the presence and location of epsilon transitions - avoid epsilon
    893 //     transitions on the output side of the first transducer or
    894 //     the input side of the second transducer or prefer placing
    895 //     them later in a path since they delay matching and can
    896 //     introduce non-coaccessible states and transitions.
    897 template<class Arc>
    898 void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
    899              MutableFst<Arc> *ofst,
    900              const ComposeOptions &opts = ComposeOptions()) {
    901   ComposeFstOptions<> nopts;
    902   nopts.gc_limit = 0;  // Cache only the last state for fastest copy.
    903   *ofst = ComposeFst<Arc>(ifst1, ifst2, nopts);
    904   if (opts.connect)
    905     Connect(ofst);
    906 }
    907 
    908 }  // namespace fst
    909 
    910 #endif  // FST_LIB_COMPOSE_H__
    911