1 // compose.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // 16 // \file 17 // Class to compute the composition of two FSTs 18 19 #ifndef FST_LIB_COMPOSE_H__ 20 #define FST_LIB_COMPOSE_H__ 21 22 #include <algorithm> 23 24 #include <ext/hash_map> 25 using __gnu_cxx::hash_map; 26 27 #include "fst/lib/cache.h" 28 #include "fst/lib/test-properties.h" 29 30 namespace fst { 31 32 // Enumeration of uint64 bits used to represent the user-defined 33 // properties of FST composition (in the template parameter to 34 // ComposeFstOptions<T>). The bits stand for extensions of generic FST 35 // composition. ComposeFstOptions<> (all the bits unset) is the "plain" 36 // compose without any extra extensions. 37 enum ComposeTypes { 38 // RHO: flags dealing with a special "rest" symbol in the FSTs. 39 // NB: at most one of the bits COMPOSE_FST1_RHO, COMPOSE_FST2_RHO 40 // may be set. 41 COMPOSE_FST1_RHO = 1ULL<<0, // "Rest" symbol on the output side of fst1. 42 COMPOSE_FST2_RHO = 1ULL<<1, // "Rest" symbol on the input side of fst2. 43 COMPOSE_FST1_PHI = 1ULL<<2, // "Failure" symbol on the output 44 // side of fst1. 45 COMPOSE_FST2_PHI = 1ULL<<3, // "Failure" symbol on the input side 46 // of fst2. 47 COMPOSE_FST1_SIGMA = 1ULL<<4, // "Any" symbol on the output side of 48 // fst1. 49 COMPOSE_FST2_SIGMA = 1ULL<<5, // "Any" symbol on the input side of 50 // fst2. 51 // Optimization related bits. 52 COMPOSE_GENERIC = 1ULL<<32, // Disables optimizations, applies 53 // the generic version of the 54 // composition algorithm. This flag 55 // is used for internal testing 56 // only. 57 58 // ----------------------------------------------------------------- 59 // Auxiliary enum values denoting specific combinations of 60 // bits. Internal use only. 61 COMPOSE_RHO = COMPOSE_FST1_RHO | COMPOSE_FST2_RHO, 62 COMPOSE_PHI = COMPOSE_FST1_PHI | COMPOSE_FST2_PHI, 63 COMPOSE_SIGMA = COMPOSE_FST1_SIGMA | COMPOSE_FST2_SIGMA, 64 COMPOSE_SPECIAL_SYMBOLS = COMPOSE_RHO | COMPOSE_PHI | COMPOSE_SIGMA, 65 66 // ----------------------------------------------------------------- 67 // The following bits, denoting specific optimizations, are 68 // typically set *internally* by the composition algorithm. 69 COMPOSE_FST1_STRING = 1ULL<<33, // fst1 is a string 70 COMPOSE_FST2_STRING = 1ULL<<34, // fst2 is a string 71 COMPOSE_FST1_DET = 1ULL<<35, // fst1 is deterministic 72 COMPOSE_FST2_DET = 1ULL<<36, // fst2 is deterministic 73 COMPOSE_INTERNAL_MASK = 0xffffffff00000000ULL 74 }; 75 76 77 template <uint64 T = 0ULL> 78 struct ComposeFstOptions : public CacheOptions { 79 explicit ComposeFstOptions(const CacheOptions &opts) : CacheOptions(opts) {} 80 ComposeFstOptions() { } 81 }; 82 83 84 // Abstract base for the implementation of delayed ComposeFst. The 85 // concrete specializations are templated on the (uint64-valued) 86 // properties of the FSTs being composed. 87 template <class A> 88 class ComposeFstImplBase : public CacheImpl<A> { 89 public: 90 using FstImpl<A>::SetType; 91 using FstImpl<A>::SetProperties; 92 using FstImpl<A>::Properties; 93 using FstImpl<A>::SetInputSymbols; 94 using FstImpl<A>::SetOutputSymbols; 95 96 using CacheBaseImpl< CacheState<A> >::HasStart; 97 using CacheBaseImpl< CacheState<A> >::HasFinal; 98 using CacheBaseImpl< CacheState<A> >::HasArcs; 99 100 typedef typename A::Label Label; 101 typedef typename A::Weight Weight; 102 typedef typename A::StateId StateId; 103 typedef CacheState<A> State; 104 105 ComposeFstImplBase(const Fst<A> &fst1, 106 const Fst<A> &fst2, 107 const CacheOptions &opts) 108 :CacheImpl<A>(opts), fst1_(fst1.Copy()), fst2_(fst2.Copy()) { 109 SetType("compose"); 110 uint64 props1 = fst1.Properties(kFstProperties, false); 111 uint64 props2 = fst2.Properties(kFstProperties, false); 112 SetProperties(ComposeProperties(props1, props2), kCopyProperties); 113 114 if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols())) 115 LOG(FATAL) << "ComposeFst: output symbol table of 1st argument " 116 << "does not match input symbol table of 2nd argument"; 117 118 SetInputSymbols(fst1.InputSymbols()); 119 SetOutputSymbols(fst2.OutputSymbols()); 120 } 121 122 virtual ~ComposeFstImplBase() { 123 delete fst1_; 124 delete fst2_; 125 } 126 127 StateId Start() { 128 if (!HasStart()) { 129 StateId start = ComputeStart(); 130 if (start != kNoStateId) { 131 this->SetStart(start); 132 } 133 } 134 return CacheImpl<A>::Start(); 135 } 136 137 Weight Final(StateId s) { 138 if (!HasFinal(s)) { 139 Weight final = ComputeFinal(s); 140 this->SetFinal(s, final); 141 } 142 return CacheImpl<A>::Final(s); 143 } 144 145 virtual void Expand(StateId s) = 0; 146 147 size_t NumArcs(StateId s) { 148 if (!HasArcs(s)) 149 Expand(s); 150 return CacheImpl<A>::NumArcs(s); 151 } 152 153 size_t NumInputEpsilons(StateId s) { 154 if (!HasArcs(s)) 155 Expand(s); 156 return CacheImpl<A>::NumInputEpsilons(s); 157 } 158 159 size_t NumOutputEpsilons(StateId s) { 160 if (!HasArcs(s)) 161 Expand(s); 162 return CacheImpl<A>::NumOutputEpsilons(s); 163 } 164 165 void InitArcIterator(StateId s, ArcIteratorData<A> *data) { 166 if (!HasArcs(s)) 167 Expand(s); 168 CacheImpl<A>::InitArcIterator(s, data); 169 } 170 171 // Access to flags encoding compose options/optimizations etc. (for 172 // debugging). 173 virtual uint64 ComposeFlags() const = 0; 174 175 protected: 176 virtual StateId ComputeStart() = 0; 177 virtual Weight ComputeFinal(StateId s) = 0; 178 179 const Fst<A> *fst1_; // first input Fst 180 const Fst<A> *fst2_; // second input Fst 181 }; 182 183 184 // The following class encapsulates implementation-dependent details 185 // of state tuple lookup, i.e. a bijective mapping from triples of two 186 // FST states and an epsilon filter state to the corresponding state 187 // IDs of the fst resulting from composition. The mapping must 188 // implement the [] operator in the style of STL associative 189 // containers (map, hash_map), i.e. table[x] must return a reference 190 // to the value associated with x. If x is an unassigned tuple, the 191 // operator must automatically associate x with value 0. 192 // 193 // NB: "table[x] == 0" for unassigned tuples x is required by the 194 // following off-by-one device used in the implementation of 195 // ComposeFstImpl. The value stored in the table is equal to tuple ID 196 // plus one, i.e. it is always a strictly positive number. Therefore, 197 // table[x] is equal to 0 if and only if x is an unassigned tuple (in 198 // which the algorithm assigns a new ID to x, and sets table[x] - 199 // stored in a reference - to "new ID + 1"). This form of lookup is 200 // more efficient than calling "find(x)" and "insert(make_pair(x, new 201 // ID))" if x is an unassigned tuple. 202 // 203 // The generic implementation is a wrapper around a hash_map. 204 template <class A, uint64 T> 205 class ComposeStateTable { 206 public: 207 typedef typename A::StateId StateId; 208 209 struct StateTuple { 210 StateTuple() {} 211 StateTuple(StateId s1, StateId s2, int f) 212 : state_id1(s1), state_id2(s2), filt(f) {} 213 StateId state_id1; // state Id on fst1 214 StateId state_id2; // state Id on fst2 215 int filt; // epsilon filter state 216 }; 217 218 ComposeStateTable() { 219 StateTuple empty_tuple(kNoStateId, kNoStateId, 0); 220 } 221 222 // NB: if 'tuple' is not in 'table_', the pair (tuple, StateId()) is 223 // inserted into 'table_' (standard STL container semantics). Since 224 // StateId is a built-in type, the explicit default constructor call 225 // StateId() returns 0. 226 StateId &operator[](const StateTuple &tuple) { 227 return table_[tuple]; 228 } 229 230 private: 231 // Comparison object for hashing StateTuple(s). 232 class StateTupleEqual { 233 public: 234 bool operator()(const StateTuple& x, const StateTuple& y) const { 235 return x.state_id1 == y.state_id1 && 236 x.state_id2 == y.state_id2 && 237 x.filt == y.filt; 238 } 239 }; 240 241 static const int kPrime0 = 7853; 242 static const int kPrime1 = 7867; 243 244 // Hash function for StateTuple to Fst states. 245 class StateTupleKey { 246 public: 247 size_t operator()(const StateTuple& x) const { 248 return static_cast<size_t>(x.state_id1 + 249 x.state_id2 * kPrime0 + 250 x.filt * kPrime1); 251 } 252 }; 253 254 // Lookup table mapping state tuples to state IDs. 255 typedef hash_map<StateTuple, 256 StateId, 257 StateTupleKey, 258 StateTupleEqual> StateTable; 259 // Actual table data. 260 StateTable table_; 261 262 DISALLOW_EVIL_CONSTRUCTORS(ComposeStateTable); 263 }; 264 265 266 // State tuple lookup table for the composition of a string FST with a 267 // deterministic FST. The class maps state tuples to their unique IDs 268 // (i.e. states of the ComposeFst). Main optimization: due to the 269 // 1-to-1 correspondence between the states of the input string FST 270 // and those of the resulting (string) FST, a state tuple (s1, s2) is 271 // simply mapped to StateId s1. Hence, we use an STL vector as a 272 // lookup table. Template argument Fst1IsString specifies which FST is 273 // a string (this determines whether or not we index the lookup table 274 // by the first or by the second state). 275 template <class A, bool Fst1IsString> 276 class StringDetComposeStateTable { 277 public: 278 typedef typename A::StateId StateId; 279 280 struct StateTuple { 281 typedef typename A::StateId StateId; 282 StateTuple() {} 283 StateTuple(StateId s1, StateId s2, int /* f */) 284 : state_id1(s1), state_id2(s2) {} 285 StateId state_id1; // state Id on fst1 286 StateId state_id2; // state Id on fst2 287 static const int filt = 0; // 'fake' epsilon filter - only needed 288 // for API compatibility 289 }; 290 291 StringDetComposeStateTable() {} 292 293 // Subscript operator. Behaves in a way similar to its map/hash_map 294 // counterpart, i.e. returns a reference to the value associated 295 // with 'tuple', inserting a 0 value if 'tuple' is unassigned. 296 StateId &operator[](const StateTuple &tuple) { 297 StateId index = Fst1IsString ? tuple.state_id1 : tuple.state_id2; 298 if (index >= (StateId)data_.size()) { 299 // NB: all values in [old_size; index] are initialized to 0. 300 data_.resize(index + 1); 301 } 302 return data_[index]; 303 } 304 305 private: 306 vector<StateId> data_; 307 308 DISALLOW_EVIL_CONSTRUCTORS(StringDetComposeStateTable); 309 }; 310 311 312 // Specializations of ComposeStateTable for the string/det case. 313 // Both inherit from StringDetComposeStateTable. 314 template <class A> 315 class ComposeStateTable<A, COMPOSE_FST1_STRING | COMPOSE_FST2_DET> 316 : public StringDetComposeStateTable<A, true> { }; 317 318 template <class A> 319 class ComposeStateTable<A, COMPOSE_FST2_STRING | COMPOSE_FST1_DET> 320 : public StringDetComposeStateTable<A, false> { }; 321 322 323 // Parameterized implementation of FST composition for a pair of FSTs 324 // matching the property bit vector T. If possible, 325 // instantiation-specific switches in the code are based on the values 326 // of the bits in T, which are known at compile time, so unused code 327 // should be optimized away by the compiler. 328 template <class A, uint64 T> 329 class ComposeFstImpl : public ComposeFstImplBase<A> { 330 typedef typename A::StateId StateId; 331 typedef typename A::Label Label; 332 typedef typename A::Weight Weight; 333 using FstImpl<A>::SetType; 334 using FstImpl<A>::SetProperties; 335 336 enum FindType { FIND_INPUT = 1, // find input label on fst2 337 FIND_OUTPUT = 2, // find output label on fst1 338 FIND_BOTH = 3 }; // find choice state dependent 339 340 typedef ComposeStateTable<A, T & COMPOSE_INTERNAL_MASK> StateTupleTable; 341 typedef typename StateTupleTable::StateTuple StateTuple; 342 343 public: 344 ComposeFstImpl(const Fst<A> &fst1, 345 const Fst<A> &fst2, 346 const CacheOptions &opts) 347 :ComposeFstImplBase<A>(fst1, fst2, opts) { 348 349 bool osorted = fst1.Properties(kOLabelSorted, false); 350 bool isorted = fst2.Properties(kILabelSorted, false); 351 352 switch (T & COMPOSE_SPECIAL_SYMBOLS) { 353 case COMPOSE_FST1_RHO: 354 case COMPOSE_FST1_PHI: 355 case COMPOSE_FST1_SIGMA: 356 if (!osorted || FLAGS_fst_verify_properties) 357 osorted = fst1.Properties(kOLabelSorted, true); 358 if (!osorted) 359 LOG(FATAL) << "ComposeFst: 1st argument not output label " 360 << "sorted (special symbols present)"; 361 break; 362 case COMPOSE_FST2_RHO: 363 case COMPOSE_FST2_PHI: 364 case COMPOSE_FST2_SIGMA: 365 if (!isorted || FLAGS_fst_verify_properties) 366 isorted = fst2.Properties(kILabelSorted, true); 367 if (!isorted) 368 LOG(FATAL) << "ComposeFst: 2nd argument not input label " 369 << "sorted (special symbols present)"; 370 break; 371 case 0: 372 if (!isorted && !osorted || FLAGS_fst_verify_properties) { 373 osorted = fst1.Properties(kOLabelSorted, true); 374 if (!osorted) 375 isorted = fst2.Properties(kILabelSorted, true); 376 } 377 break; 378 default: 379 LOG(FATAL) 380 << "ComposeFst: More than one special symbol used in composition"; 381 } 382 383 if (isorted && (T & COMPOSE_FST2_SIGMA)) { 384 find_type_ = FIND_INPUT; 385 } else if (osorted && (T & COMPOSE_FST1_SIGMA)) { 386 find_type_ = FIND_OUTPUT; 387 } else if (isorted && (T & COMPOSE_FST2_PHI)) { 388 find_type_ = FIND_INPUT; 389 } else if (osorted && (T & COMPOSE_FST1_PHI)) { 390 find_type_ = FIND_OUTPUT; 391 } else if (isorted && (T & COMPOSE_FST2_RHO)) { 392 find_type_ = FIND_INPUT; 393 } else if (osorted && (T & COMPOSE_FST1_RHO)) { 394 find_type_ = FIND_OUTPUT; 395 } else if (isorted && (T & COMPOSE_FST1_STRING)) { 396 find_type_ = FIND_INPUT; 397 } else if(osorted && (T & COMPOSE_FST2_STRING)) { 398 find_type_ = FIND_OUTPUT; 399 } else if (isorted && osorted) { 400 find_type_ = FIND_BOTH; 401 } else if (isorted) { 402 find_type_ = FIND_INPUT; 403 } else if (osorted) { 404 find_type_ = FIND_OUTPUT; 405 } else { 406 LOG(FATAL) << "ComposeFst: 1st argument not output label sorted " 407 << "and 2nd argument is not input label sorted"; 408 } 409 } 410 411 // Finds/creates an Fst state given a StateTuple. Only creates a new 412 // state if StateTuple is not found in the state hash. 413 // 414 // The method exploits the following device: all pairs stored in the 415 // associative container state_tuple_table_ are of the form (tuple, 416 // id(tuple) + 1), i.e. state_tuple_table_[tuple] > 0 if tuple has 417 // been stored previously. For unassigned tuples, the call to 418 // state_tuple_table_[tuple] creates a new pair (tuple, 0). As a 419 // result, state_tuple_table_[tuple] == 0 iff tuple is new. 420 StateId FindState(const StateTuple& tuple) { 421 StateId &assoc_value = state_tuple_table_[tuple]; 422 if (assoc_value == 0) { // tuple wasn't present in lookup table: 423 // assign it a new ID. 424 state_tuples_.push_back(tuple); 425 assoc_value = state_tuples_.size(); 426 } 427 return assoc_value - 1; // NB: assoc_value = ID + 1 428 } 429 430 // Generates arc for composition state s from matched input Fst arcs. 431 void AddArc(StateId s, const A &arca, const A &arcb, int f, 432 bool find_input) { 433 A arc; 434 if (find_input) { 435 arc.ilabel = arcb.ilabel; 436 arc.olabel = arca.olabel; 437 arc.weight = Times(arcb.weight, arca.weight); 438 StateTuple tuple(arcb.nextstate, arca.nextstate, f); 439 arc.nextstate = FindState(tuple); 440 } else { 441 arc.ilabel = arca.ilabel; 442 arc.olabel = arcb.olabel; 443 arc.weight = Times(arca.weight, arcb.weight); 444 StateTuple tuple(arca.nextstate, arcb.nextstate, f); 445 arc.nextstate = FindState(tuple); 446 } 447 CacheImpl<A>::AddArc(s, arc); 448 } 449 450 // Arranges it so that the first arg to OrderedExpand is the Fst 451 // that will be passed to FindLabel. 452 void Expand(StateId s) { 453 StateTuple &tuple = state_tuples_[s]; 454 StateId s1 = tuple.state_id1; 455 StateId s2 = tuple.state_id2; 456 int f = tuple.filt; 457 if (find_type_ == FIND_INPUT) 458 OrderedExpand(s, ComposeFstImplBase<A>::fst2_, s2, 459 ComposeFstImplBase<A>::fst1_, s1, f, true); 460 else 461 OrderedExpand(s, ComposeFstImplBase<A>::fst1_, s1, 462 ComposeFstImplBase<A>::fst2_, s2, f, false); 463 } 464 465 // Access to flags encoding compose options/optimizations etc. (for 466 // debugging). 467 virtual uint64 ComposeFlags() const { return T; } 468 469 private: 470 // This does that actual matching of labels in the composition. The 471 // arguments are ordered so FindLabel is called with state SA of 472 // FSTA for each arc leaving state SB of FSTB. The FIND_INPUT arg 473 // determines whether the input or output label of arcs at SB is 474 // the one to match on. 475 void OrderedExpand(StateId s, const Fst<A> *fsta, StateId sa, 476 const Fst<A> *fstb, StateId sb, int f, bool find_input) { 477 478 size_t numarcsa = fsta->NumArcs(sa); 479 size_t numepsa = find_input ? fsta->NumInputEpsilons(sa) : 480 fsta->NumOutputEpsilons(sa); 481 bool finala = fsta->Final(sa) != Weight::Zero(); 482 ArcIterator< Fst<A> > aitera(*fsta, sa); 483 // First handle special epsilons and sigmas on FSTA 484 for (; !aitera.Done(); aitera.Next()) { 485 const A &arca = aitera.Value(); 486 Label match_labela = find_input ? arca.ilabel : arca.olabel; 487 if (match_labela > 0) { 488 break; 489 } 490 if ((T & COMPOSE_SIGMA) != 0 && match_labela == kSigmaLabel) { 491 // Found a sigma? Match it against all (non-special) symbols 492 // on side b. 493 for (ArcIterator< Fst<A> > aiterb(*fstb, sb); 494 !aiterb.Done(); 495 aiterb.Next()) { 496 const A &arcb = aiterb.Value(); 497 Label labelb = find_input ? arcb.olabel : arcb.ilabel; 498 if (labelb <= 0) continue; 499 AddArc(s, arca, arcb, 0, find_input); 500 } 501 } else if (f == 0 && match_labela == 0) { 502 A earcb(0, 0, Weight::One(), sb); 503 AddArc(s, arca, earcb, 0, find_input); // move forward on epsilon 504 } 505 } 506 // Next handle non-epsilon matches, rho labels, and epsilons on FSTB 507 for (ArcIterator< Fst<A> > aiterb(*fstb, sb); 508 !aiterb.Done(); 509 aiterb.Next()) { 510 const A &arcb = aiterb.Value(); 511 Label match_labelb = find_input ? arcb.olabel : arcb.ilabel; 512 if (match_labelb) { // Consider non-epsilon match 513 if (FindLabel(&aitera, numarcsa, match_labelb, find_input)) { 514 for (; !aitera.Done(); aitera.Next()) { 515 const A &arca = aitera.Value(); 516 Label match_labela = find_input ? arca.ilabel : arca.olabel; 517 if (match_labela != match_labelb) 518 break; 519 AddArc(s, arca, arcb, 0, find_input); // move forward on match 520 } 521 } else if ((T & COMPOSE_SPECIAL_SYMBOLS) != 0) { 522 // If there is no transition labelled 'match_labelb' in 523 // fsta, try matching 'match_labelb' against special symbols 524 // (Phi, Rho,...). 525 for (aitera.Reset(); !aitera.Done(); aitera.Next()) { 526 A arca = aitera.Value(); 527 Label labela = find_input ? arca.ilabel : arca.olabel; 528 if (labela >= 0) { 529 break; 530 } else if (((T & COMPOSE_PHI) != 0) && (labela == kPhiLabel)) { 531 // Case 1: if a failure transition exists, follow its 532 // transitive closure until a) a transition labelled 533 // 'match_labelb' is found, or b) the initial state of 534 // fsta is reached. 535 536 StateId sf = sa; // Start of current failure transition. 537 while (labela == kPhiLabel && sf != arca.nextstate) { 538 sf = arca.nextstate; 539 540 size_t numarcsf = fsta->NumArcs(sf); 541 ArcIterator< Fst<A> > aiterf(*fsta, sf); 542 if (FindLabel(&aiterf, numarcsf, match_labelb, find_input)) { 543 // Sub-case 1a: there exists a transition starting 544 // in sf and consuming symbol 'match_labelb'. 545 AddArc(s, aiterf.Value(), arcb, 0, find_input); 546 break; 547 } else { 548 // No transition labelled 'match_labelb' found: try 549 // next failure transition (starting at 'sf'). 550 for (aiterf.Reset(); !aiterf.Done(); aiterf.Next()) { 551 arca = aiterf.Value(); 552 labela = find_input ? arca.ilabel : arca.olabel; 553 if (labela >= kPhiLabel) break; 554 } 555 } 556 } 557 if (labela == kPhiLabel && sf == arca.nextstate) { 558 // Sub-case 1b: failure transitions lead to start 559 // state without finding a matching 560 // transition. Therefore, we generate a loop in start 561 // state of fsta. 562 A loop(match_labelb, match_labelb, Weight::One(), sf); 563 AddArc(s, loop, arcb, 0, find_input); 564 } 565 } else if (((T & COMPOSE_RHO) != 0) && (labela == kRhoLabel)) { 566 // Case 2: 'match_labelb' can be matched against a 567 // "rest" (rho) label in fsta. 568 if (find_input) { 569 arca.ilabel = match_labelb; 570 if (arca.olabel == kRhoLabel) 571 arca.olabel = match_labelb; 572 } else { 573 arca.olabel = match_labelb; 574 if (arca.ilabel == kRhoLabel) 575 arca.ilabel = match_labelb; 576 } 577 AddArc(s, arca, arcb, 0, find_input); // move fwd on match 578 } 579 } 580 } 581 } else if (numepsa != numarcsa || finala) { // Handle FSTB epsilon 582 A earca(0, 0, Weight::One(), sa); 583 AddArc(s, earca, arcb, numepsa > 0, find_input); // move on epsilon 584 } 585 } 586 this->SetArcs(s); 587 } 588 589 590 // Finds matches to MATCH_LABEL in arcs given by AITER 591 // using FIND_INPUT to determine whether to look on input or output. 592 bool FindLabel(ArcIterator< Fst<A> > *aiter, size_t numarcs, 593 Label match_label, bool find_input) { 594 // binary search for match 595 size_t low = 0; 596 size_t high = numarcs; 597 while (low < high) { 598 size_t mid = (low + high) / 2; 599 aiter->Seek(mid); 600 Label label = find_input ? 601 aiter->Value().ilabel : aiter->Value().olabel; 602 if (label > match_label) { 603 high = mid; 604 } else if (label < match_label) { 605 low = mid + 1; 606 } else { 607 // find first matching label (when non-determinism) 608 for (size_t i = mid; i > low; --i) { 609 aiter->Seek(i - 1); 610 label = find_input ? aiter->Value().ilabel : aiter->Value().olabel; 611 if (label != match_label) { 612 aiter->Seek(i); 613 return true; 614 } 615 } 616 return true; 617 } 618 } 619 return false; 620 } 621 622 StateId ComputeStart() { 623 StateId s1 = ComposeFstImplBase<A>::fst1_->Start(); 624 StateId s2 = ComposeFstImplBase<A>::fst2_->Start(); 625 if (s1 == kNoStateId || s2 == kNoStateId) 626 return kNoStateId; 627 StateTuple tuple(s1, s2, 0); 628 return FindState(tuple); 629 } 630 631 Weight ComputeFinal(StateId s) { 632 StateTuple &tuple = state_tuples_[s]; 633 Weight final = Times(ComposeFstImplBase<A>::fst1_->Final(tuple.state_id1), 634 ComposeFstImplBase<A>::fst2_->Final(tuple.state_id2)); 635 return final; 636 } 637 638 639 FindType find_type_; // find label on which side? 640 641 // Maps from StateId to StateTuple. 642 vector<StateTuple> state_tuples_; 643 644 // Maps from StateTuple to StateId. 645 StateTupleTable state_tuple_table_; 646 647 DISALLOW_EVIL_CONSTRUCTORS(ComposeFstImpl); 648 }; 649 650 651 // Computes the composition of two transducers. This version is a 652 // delayed Fst. If FST1 transduces string x to y with weight a and FST2 653 // transduces y to z with weight b, then their composition transduces 654 // string x to z with weight Times(x, z). 655 // 656 // The output labels of the first transducer or the input labels of 657 // the second transducer must be sorted. The weights need to form a 658 // commutative semiring (valid for TropicalWeight and LogWeight). 659 // 660 // Complexity: 661 // Assuming the first FST is unsorted and the second is sorted: 662 // - Time: O(v1 v2 d1 (log d2 + m2)), 663 // - Space: O(v1 v2) 664 // where vi = # of states visited, di = maximum out-degree, and mi the 665 // maximum multiplicity of the states visited for the ith 666 // FST. Constant time and space to visit an input state or arc is 667 // assumed and exclusive of caching. 668 // 669 // Caveats: 670 // - ComposeFst does not trim its output (since it is a delayed operation). 671 // - The efficiency of composition can be strongly affected by several factors: 672 // - the choice of which tnansducer is sorted - prefer sorting the FST 673 // that has the greater average out-degree. 674 // - the amount of non-determinism 675 // - the presence and location of epsilon transitions - avoid epsilon 676 // transitions on the output side of the first transducer or 677 // the input side of the second transducer or prefer placing 678 // them later in a path since they delay matching and can 679 // introduce non-coaccessible states and transitions. 680 template <class A> 681 class ComposeFst : public Fst<A> { 682 public: 683 friend class ArcIterator< ComposeFst<A> >; 684 friend class CacheStateIterator< ComposeFst<A> >; 685 friend class CacheArcIterator< ComposeFst<A> >; 686 687 typedef A Arc; 688 typedef typename A::Weight Weight; 689 typedef typename A::StateId StateId; 690 typedef CacheState<A> State; 691 692 ComposeFst(const Fst<A> &fst1, const Fst<A> &fst2) 693 : impl_(Init(fst1, fst2, ComposeFstOptions<>())) { } 694 695 template <uint64 T> 696 ComposeFst(const Fst<A> &fst1, 697 const Fst<A> &fst2, 698 const ComposeFstOptions<T> &opts) 699 : impl_(Init(fst1, fst2, opts)) { } 700 701 ComposeFst(const ComposeFst<A> &fst) : Fst<A>(fst), impl_(fst.impl_) { 702 impl_->IncrRefCount(); 703 } 704 705 virtual ~ComposeFst() { if (!impl_->DecrRefCount()) delete impl_; } 706 707 virtual StateId Start() const { return impl_->Start(); } 708 709 virtual Weight Final(StateId s) const { return impl_->Final(s); } 710 711 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } 712 713 virtual size_t NumInputEpsilons(StateId s) const { 714 return impl_->NumInputEpsilons(s); 715 } 716 717 virtual size_t NumOutputEpsilons(StateId s) const { 718 return impl_->NumOutputEpsilons(s); 719 } 720 721 virtual uint64 Properties(uint64 mask, bool test) const { 722 if (test) { 723 uint64 known, test = TestProperties(*this, mask, &known); 724 impl_->SetProperties(test, known); 725 return test & mask; 726 } else { 727 return impl_->Properties(mask); 728 } 729 } 730 731 virtual const string& Type() const { return impl_->Type(); } 732 733 virtual ComposeFst<A> *Copy() const { 734 return new ComposeFst<A>(*this); 735 } 736 737 virtual const SymbolTable* InputSymbols() const { 738 return impl_->InputSymbols(); 739 } 740 741 virtual const SymbolTable* OutputSymbols() const { 742 return impl_->OutputSymbols(); 743 } 744 745 virtual inline void InitStateIterator(StateIteratorData<A> *data) const; 746 747 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 748 impl_->InitArcIterator(s, data); 749 } 750 751 // Access to flags encoding compose options/optimizations etc. (for 752 // debugging). 753 uint64 ComposeFlags() const { return impl_->ComposeFlags(); } 754 755 protected: 756 ComposeFstImplBase<A> *Impl() { return impl_; } 757 758 private: 759 ComposeFstImplBase<A> *impl_; 760 761 // Auxiliary method encapsulating the creation of a ComposeFst 762 // implementation that is appropriate for the properties of fst1 and 763 // fst2. 764 template <uint64 T> 765 static ComposeFstImplBase<A> *Init( 766 const Fst<A> &fst1, 767 const Fst<A> &fst2, 768 const ComposeFstOptions<T> &opts) { 769 770 // Filter for sort properties (forces a property check). 771 uint64 sort_props_mask = kILabelSorted | kOLabelSorted; 772 // Filter for optimization-related properties (does not force a 773 // property-check). 774 uint64 opt_props_mask = 775 kString | kIDeterministic | kODeterministic | kNoIEpsilons | 776 kNoOEpsilons; 777 778 uint64 props1 = fst1.Properties(sort_props_mask, true); 779 uint64 props2 = fst2.Properties(sort_props_mask, true); 780 781 props1 |= fst1.Properties(opt_props_mask, false); 782 props2 |= fst2.Properties(opt_props_mask, false); 783 784 if (!(Weight::Properties() & kCommutative)) { 785 props1 |= fst1.Properties(kUnweighted, true); 786 props2 |= fst2.Properties(kUnweighted, true); 787 if (!(props1 & kUnweighted) && !(props2 & kUnweighted)) 788 LOG(FATAL) << "ComposeFst: Weight needs to be a commutative semiring: " 789 << Weight::Type(); 790 } 791 792 // Case 1: flag COMPOSE_GENERIC disables optimizations. 793 if (T & COMPOSE_GENERIC) { 794 return new ComposeFstImpl<A, T>(fst1, fst2, opts); 795 } 796 797 const uint64 kStringDetOptProps = 798 kIDeterministic | kILabelSorted | kNoIEpsilons; 799 const uint64 kDetStringOptProps = 800 kODeterministic | kOLabelSorted | kNoOEpsilons; 801 802 // Case 2: fst1 is a string, fst2 is deterministic and epsilon-free. 803 if ((props1 & kString) && 804 !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) && 805 ((props2 & kStringDetOptProps) == kStringDetOptProps)) { 806 return new ComposeFstImpl<A, T | COMPOSE_FST1_STRING | COMPOSE_FST2_DET>( 807 fst1, fst2, opts); 808 } 809 // Case 3: fst1 is deterministic and epsilon-free, fst2 is string. 810 if ((props2 & kString) && 811 !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) && 812 ((props1 & kDetStringOptProps) == kDetStringOptProps)) { 813 return new ComposeFstImpl<A, T | COMPOSE_FST2_STRING | COMPOSE_FST1_DET>( 814 fst1, fst2, opts); 815 } 816 817 // Default case: no optimizations. 818 return new ComposeFstImpl<A, T>(fst1, fst2, opts); 819 } 820 821 void operator=(const ComposeFst<A> &fst); // disallow 822 }; 823 824 825 // Specialization for ComposeFst. 826 template<class A> 827 class StateIterator< ComposeFst<A> > 828 : public CacheStateIterator< ComposeFst<A> > { 829 public: 830 explicit StateIterator(const ComposeFst<A> &fst) 831 : CacheStateIterator< ComposeFst<A> >(fst) {} 832 }; 833 834 835 // Specialization for ComposeFst. 836 template <class A> 837 class ArcIterator< ComposeFst<A> > 838 : public CacheArcIterator< ComposeFst<A> > { 839 public: 840 typedef typename A::StateId StateId; 841 842 ArcIterator(const ComposeFst<A> &fst, StateId s) 843 : CacheArcIterator< ComposeFst<A> >(fst, s) { 844 if (!fst.impl_->HasArcs(s)) 845 fst.impl_->Expand(s); 846 } 847 848 private: 849 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); 850 }; 851 852 template <class A> inline 853 void ComposeFst<A>::InitStateIterator(StateIteratorData<A> *data) const { 854 data->base = new StateIterator< ComposeFst<A> >(*this); 855 } 856 857 // Useful alias when using StdArc. 858 typedef ComposeFst<StdArc> StdComposeFst; 859 860 861 struct ComposeOptions { 862 bool connect; // Connect output 863 864 ComposeOptions(bool c) : connect(c) {} 865 ComposeOptions() : connect(true) { } 866 }; 867 868 869 // Computes the composition of two transducers. This version writes 870 // the composed FST into a MurableFst. If FST1 transduces string x to 871 // y with weight a and FST2 transduces y to z with weight b, then 872 // their composition transduces string x to z with weight 873 // Times(x, z). 874 // 875 // The output labels of the first transducer or the input labels of 876 // the second transducer must be sorted. The weights need to form a 877 // commutative semiring (valid for TropicalWeight and LogWeight). 878 // 879 // Complexity: 880 // Assuming the first FST is unsorted and the second is sorted: 881 // - Time: O(V1 V2 D1 (log D2 + M2)), 882 // - Space: O(V1 V2 D1 M2) 883 // where Vi = # of states, Di = maximum out-degree, and Mi is 884 // the maximum multiplicity for the ith FST. 885 // 886 // Caveats: 887 // - Compose trims its output. 888 // - The efficiency of composition can be strongly affected by several factors: 889 // - the choice of which tnansducer is sorted - prefer sorting the FST 890 // that has the greater average out-degree. 891 // - the amount of non-determinism 892 // - the presence and location of epsilon transitions - avoid epsilon 893 // transitions on the output side of the first transducer or 894 // the input side of the second transducer or prefer placing 895 // them later in a path since they delay matching and can 896 // introduce non-coaccessible states and transitions. 897 template<class Arc> 898 void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2, 899 MutableFst<Arc> *ofst, 900 const ComposeOptions &opts = ComposeOptions()) { 901 ComposeFstOptions<> nopts; 902 nopts.gc_limit = 0; // Cache only the last state for fastest copy. 903 *ofst = ComposeFst<Arc>(ifst1, ifst2, nopts); 904 if (opts.connect) 905 Connect(ofst); 906 } 907 908 } // namespace fst 909 910 #endif // FST_LIB_COMPOSE_H__ 911