// compose.h
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
// \file
// Class to compute the composition of two FSTs

#ifndef FST_LIB_COMPOSE_H__
#define FST_LIB_COMPOSE_H__

#include <algorithm>

#include <unordered_map>

#include "fst/lib/cache.h"
#include "fst/lib/test-properties.h"

namespace fst {

// Enumeration of uint64 bits used to represent the user-defined
// properties of FST composition (in the template parameter to
// ComposeFstOptions<T>). The bits stand for extensions of generic FST
// composition. ComposeFstOptions<> (all the bits unset) is the "plain"
// compose without any extra extensions.
//
// Layout: the low 32 bits are user-selectable special-symbol flags;
// the high 32 bits (see COMPOSE_INTERNAL_MASK) hold COMPOSE_GENERIC
// and the optimization bits that ComposeFst::Init() adds on its own.
enum ComposeTypes {
  // RHO: flags dealing with a special "rest" symbol in the FSTs.
  // NB: at most one of the bits COMPOSE_FST1_RHO, COMPOSE_FST2_RHO
  // may be set. More generally, ComposeFstImpl's constructor aborts
  // with LOG(FATAL) when more than one of the six special-symbol bits
  // below is set.
  COMPOSE_FST1_RHO = 1ULL<<0,  // "Rest" symbol on the output side of fst1.
  COMPOSE_FST2_RHO = 1ULL<<1,  // "Rest" symbol on the input side of fst2.
  COMPOSE_FST1_PHI = 1ULL<<2,  // "Failure" symbol on the output
                               // side of fst1.
  COMPOSE_FST2_PHI = 1ULL<<3,  // "Failure" symbol on the input side
                               // of fst2.
  COMPOSE_FST1_SIGMA = 1ULL<<4,  // "Any" symbol on the output side of
                                 // fst1.
  COMPOSE_FST2_SIGMA = 1ULL<<5,  // "Any" symbol on the input side of
                                 // fst2.
  // Optimization related bits.
  COMPOSE_GENERIC = 1ULL<<32,  // Disables optimizations, applies
                               // the generic version of the
                               // composition algorithm. This flag
                               // is used for internal testing
                               // only.

  // -----------------------------------------------------------------
  // Auxiliary enum values denoting specific combinations of
  // bits. Internal use only.
  COMPOSE_RHO = COMPOSE_FST1_RHO | COMPOSE_FST2_RHO,
  COMPOSE_PHI = COMPOSE_FST1_PHI | COMPOSE_FST2_PHI,
  COMPOSE_SIGMA = COMPOSE_FST1_SIGMA | COMPOSE_FST2_SIGMA,
  COMPOSE_SPECIAL_SYMBOLS = COMPOSE_RHO | COMPOSE_PHI | COMPOSE_SIGMA,

  // -----------------------------------------------------------------
  // The following bits, denoting specific optimizations, are
  // typically set *internally* by the composition algorithm.
  COMPOSE_FST1_STRING = 1ULL<<33,  // fst1 is a string
  COMPOSE_FST2_STRING = 1ULL<<34,  // fst2 is a string
  COMPOSE_FST1_DET = 1ULL<<35,  // fst1 is deterministic
  COMPOSE_FST2_DET = 1ULL<<36,  // fst2 is deterministic
  // Selects the upper 32 bits, i.e. COMPOSE_GENERIC together with the
  // internally-set optimization bits above. ComposeFstImpl applies it
  // to T to choose its ComposeStateTable specialization.
  COMPOSE_INTERNAL_MASK = 0xffffffff00000000ULL
};


// Options for ComposeFst. Carries the CacheOptions of the delayed FST
// at run time; the template argument T is a compile-time bitwise OR
// of ComposeTypes flags (0 selects plain composition).
template <uint64 T = 0ULL>
struct ComposeFstOptions : public CacheOptions {
  // Adopts existing cache options.
  explicit ComposeFstOptions(const CacheOptions &opts) : CacheOptions(opts) {}
  // Uses default-constructed cache options.
  ComposeFstOptions() { }
};


// Abstract base for the implementation of delayed ComposeFst. The
// concrete specializations are templated on the (uint64-valued)
// properties of the FSTs being composed.
86 template <class A> 87 class ComposeFstImplBase : public CacheImpl<A> { 88 public: 89 using FstImpl<A>::SetType; 90 using FstImpl<A>::SetProperties; 91 using FstImpl<A>::Properties; 92 using FstImpl<A>::SetInputSymbols; 93 using FstImpl<A>::SetOutputSymbols; 94 95 using CacheBaseImpl< CacheState<A> >::HasStart; 96 using CacheBaseImpl< CacheState<A> >::HasFinal; 97 using CacheBaseImpl< CacheState<A> >::HasArcs; 98 99 typedef typename A::Label Label; 100 typedef typename A::Weight Weight; 101 typedef typename A::StateId StateId; 102 typedef CacheState<A> State; 103 104 ComposeFstImplBase(const Fst<A> &fst1, 105 const Fst<A> &fst2, 106 const CacheOptions &opts) 107 :CacheImpl<A>(opts), fst1_(fst1.Copy()), fst2_(fst2.Copy()) { 108 SetType("compose"); 109 uint64 props1 = fst1.Properties(kFstProperties, false); 110 uint64 props2 = fst2.Properties(kFstProperties, false); 111 SetProperties(ComposeProperties(props1, props2), kCopyProperties); 112 113 if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols())) 114 LOG(FATAL) << "ComposeFst: output symbol table of 1st argument " 115 << "does not match input symbol table of 2nd argument"; 116 117 SetInputSymbols(fst1.InputSymbols()); 118 SetOutputSymbols(fst2.OutputSymbols()); 119 } 120 121 virtual ~ComposeFstImplBase() { 122 delete fst1_; 123 delete fst2_; 124 } 125 126 StateId Start() { 127 if (!HasStart()) { 128 StateId start = ComputeStart(); 129 if (start != kNoStateId) { 130 this->SetStart(start); 131 } 132 } 133 return CacheImpl<A>::Start(); 134 } 135 136 Weight Final(StateId s) { 137 if (!HasFinal(s)) { 138 Weight final = ComputeFinal(s); 139 this->SetFinal(s, final); 140 } 141 return CacheImpl<A>::Final(s); 142 } 143 144 virtual void Expand(StateId s) = 0; 145 146 size_t NumArcs(StateId s) { 147 if (!HasArcs(s)) 148 Expand(s); 149 return CacheImpl<A>::NumArcs(s); 150 } 151 152 size_t NumInputEpsilons(StateId s) { 153 if (!HasArcs(s)) 154 Expand(s); 155 return CacheImpl<A>::NumInputEpsilons(s); 156 } 157 158 
size_t NumOutputEpsilons(StateId s) { 159 if (!HasArcs(s)) 160 Expand(s); 161 return CacheImpl<A>::NumOutputEpsilons(s); 162 } 163 164 void InitArcIterator(StateId s, ArcIteratorData<A> *data) { 165 if (!HasArcs(s)) 166 Expand(s); 167 CacheImpl<A>::InitArcIterator(s, data); 168 } 169 170 // Access to flags encoding compose options/optimizations etc. (for 171 // debugging). 172 virtual uint64 ComposeFlags() const = 0; 173 174 protected: 175 virtual StateId ComputeStart() = 0; 176 virtual Weight ComputeFinal(StateId s) = 0; 177 178 const Fst<A> *fst1_; // first input Fst 179 const Fst<A> *fst2_; // second input Fst 180 }; 181 182 183 // The following class encapsulates implementation-dependent details 184 // of state tuple lookup, i.e. a bijective mapping from triples of two 185 // FST states and an epsilon filter state to the corresponding state 186 // IDs of the fst resulting from composition. The mapping must 187 // implement the [] operator in the style of STL associative 188 // containers (map, hash_map), i.e. table[x] must return a reference 189 // to the value associated with x. If x is an unassigned tuple, the 190 // operator must automatically associate x with value 0. 191 // 192 // NB: "table[x] == 0" for unassigned tuples x is required by the 193 // following off-by-one device used in the implementation of 194 // ComposeFstImpl. The value stored in the table is equal to tuple ID 195 // plus one, i.e. it is always a strictly positive number. Therefore, 196 // table[x] is equal to 0 if and only if x is an unassigned tuple (in 197 // which the algorithm assigns a new ID to x, and sets table[x] - 198 // stored in a reference - to "new ID + 1"). This form of lookup is 199 // more efficient than calling "find(x)" and "insert(make_pair(x, new 200 // ID))" if x is an unassigned tuple. 201 // 202 // The generic implementation is a wrapper around a hash_map. 
203 template <class A, uint64 T> 204 class ComposeStateTable { 205 public: 206 typedef typename A::StateId StateId; 207 208 struct StateTuple { 209 StateTuple() {} 210 StateTuple(StateId s1, StateId s2, int f) 211 : state_id1(s1), state_id2(s2), filt(f) {} 212 StateId state_id1; // state Id on fst1 213 StateId state_id2; // state Id on fst2 214 int filt; // epsilon filter state 215 }; 216 217 ComposeStateTable() { 218 StateTuple empty_tuple(kNoStateId, kNoStateId, 0); 219 } 220 221 // NB: if 'tuple' is not in 'table_', the pair (tuple, StateId()) is 222 // inserted into 'table_' (standard STL container semantics). Since 223 // StateId is a built-in type, the explicit default constructor call 224 // StateId() returns 0. 225 StateId &operator[](const StateTuple &tuple) { 226 return table_[tuple]; 227 } 228 229 private: 230 // Comparison object for hashing StateTuple(s). 231 class StateTupleEqual { 232 public: 233 bool operator()(const StateTuple& x, const StateTuple& y) const { 234 return x.state_id1 == y.state_id1 && 235 x.state_id2 == y.state_id2 && 236 x.filt == y.filt; 237 } 238 }; 239 240 static const int kPrime0 = 7853; 241 static const int kPrime1 = 7867; 242 243 // Hash function for StateTuple to Fst states. 244 class StateTupleKey { 245 public: 246 size_t operator()(const StateTuple& x) const { 247 return static_cast<size_t>(x.state_id1 + 248 x.state_id2 * kPrime0 + 249 x.filt * kPrime1); 250 } 251 }; 252 253 // Lookup table mapping state tuples to state IDs. 254 typedef std::unordered_map<StateTuple, StateId, StateTupleKey, 255 StateTupleEqual> StateTable; 256 // Actual table data. 257 StateTable table_; 258 259 DISALLOW_EVIL_CONSTRUCTORS(ComposeStateTable); 260 }; 261 262 263 // State tuple lookup table for the composition of a string FST with a 264 // deterministic FST. The class maps state tuples to their unique IDs 265 // (i.e. states of the ComposeFst). 
Main optimization: due to the 266 // 1-to-1 correspondence between the states of the input string FST 267 // and those of the resulting (string) FST, a state tuple (s1, s2) is 268 // simply mapped to StateId s1. Hence, we use an STL vector as a 269 // lookup table. Template argument Fst1IsString specifies which FST is 270 // a string (this determines whether or not we index the lookup table 271 // by the first or by the second state). 272 template <class A, bool Fst1IsString> 273 class StringDetComposeStateTable { 274 public: 275 typedef typename A::StateId StateId; 276 277 struct StateTuple { 278 typedef typename A::StateId StateId; 279 StateTuple() {} 280 StateTuple(StateId s1, StateId s2, int /* f */) 281 : state_id1(s1), state_id2(s2) {} 282 StateId state_id1; // state Id on fst1 283 StateId state_id2; // state Id on fst2 284 static const int filt = 0; // 'fake' epsilon filter - only needed 285 // for API compatibility 286 }; 287 288 StringDetComposeStateTable() {} 289 290 // Subscript operator. Behaves in a way similar to its map/hash_map 291 // counterpart, i.e. returns a reference to the value associated 292 // with 'tuple', inserting a 0 value if 'tuple' is unassigned. 293 StateId &operator[](const StateTuple &tuple) { 294 StateId index = Fst1IsString ? tuple.state_id1 : tuple.state_id2; 295 if (index >= (StateId)data_.size()) { 296 // NB: all values in [old_size; index] are initialized to 0. 297 data_.resize(index + 1); 298 } 299 return data_[index]; 300 } 301 302 private: 303 vector<StateId> data_; 304 305 DISALLOW_EVIL_CONSTRUCTORS(StringDetComposeStateTable); 306 }; 307 308 309 // Specializations of ComposeStateTable for the string/det case. 310 // Both inherit from StringDetComposeStateTable. 
311 template <class A> 312 class ComposeStateTable<A, COMPOSE_FST1_STRING | COMPOSE_FST2_DET> 313 : public StringDetComposeStateTable<A, true> { }; 314 315 template <class A> 316 class ComposeStateTable<A, COMPOSE_FST2_STRING | COMPOSE_FST1_DET> 317 : public StringDetComposeStateTable<A, false> { }; 318 319 320 // Parameterized implementation of FST composition for a pair of FSTs 321 // matching the property bit vector T. If possible, 322 // instantiation-specific switches in the code are based on the values 323 // of the bits in T, which are known at compile time, so unused code 324 // should be optimized away by the compiler. 325 template <class A, uint64 T> 326 class ComposeFstImpl : public ComposeFstImplBase<A> { 327 typedef typename A::StateId StateId; 328 typedef typename A::Label Label; 329 typedef typename A::Weight Weight; 330 using FstImpl<A>::SetType; 331 using FstImpl<A>::SetProperties; 332 333 enum FindType { FIND_INPUT = 1, // find input label on fst2 334 FIND_OUTPUT = 2, // find output label on fst1 335 FIND_BOTH = 3 }; // find choice state dependent 336 337 typedef ComposeStateTable<A, T & COMPOSE_INTERNAL_MASK> StateTupleTable; 338 typedef typename StateTupleTable::StateTuple StateTuple; 339 340 public: 341 ComposeFstImpl(const Fst<A> &fst1, 342 const Fst<A> &fst2, 343 const CacheOptions &opts) 344 :ComposeFstImplBase<A>(fst1, fst2, opts) { 345 346 bool osorted = fst1.Properties(kOLabelSorted, false); 347 bool isorted = fst2.Properties(kILabelSorted, false); 348 349 switch (T & COMPOSE_SPECIAL_SYMBOLS) { 350 case COMPOSE_FST1_RHO: 351 case COMPOSE_FST1_PHI: 352 case COMPOSE_FST1_SIGMA: 353 if (!osorted || FLAGS_fst_verify_properties) 354 osorted = fst1.Properties(kOLabelSorted, true); 355 if (!osorted) 356 LOG(FATAL) << "ComposeFst: 1st argument not output label " 357 << "sorted (special symbols present)"; 358 break; 359 case COMPOSE_FST2_RHO: 360 case COMPOSE_FST2_PHI: 361 case COMPOSE_FST2_SIGMA: 362 if (!isorted || FLAGS_fst_verify_properties) 
363 isorted = fst2.Properties(kILabelSorted, true); 364 if (!isorted) 365 LOG(FATAL) << "ComposeFst: 2nd argument not input label " 366 << "sorted (special symbols present)"; 367 break; 368 case 0: 369 if ((!isorted && !osorted) || FLAGS_fst_verify_properties) { 370 osorted = fst1.Properties(kOLabelSorted, true); 371 if (!osorted) 372 isorted = fst2.Properties(kILabelSorted, true); 373 } 374 break; 375 default: 376 LOG(FATAL) 377 << "ComposeFst: More than one special symbol used in composition"; 378 } 379 380 if (isorted && (T & COMPOSE_FST2_SIGMA)) { 381 find_type_ = FIND_INPUT; 382 } else if (osorted && (T & COMPOSE_FST1_SIGMA)) { 383 find_type_ = FIND_OUTPUT; 384 } else if (isorted && (T & COMPOSE_FST2_PHI)) { 385 find_type_ = FIND_INPUT; 386 } else if (osorted && (T & COMPOSE_FST1_PHI)) { 387 find_type_ = FIND_OUTPUT; 388 } else if (isorted && (T & COMPOSE_FST2_RHO)) { 389 find_type_ = FIND_INPUT; 390 } else if (osorted && (T & COMPOSE_FST1_RHO)) { 391 find_type_ = FIND_OUTPUT; 392 } else if (isorted && (T & COMPOSE_FST1_STRING)) { 393 find_type_ = FIND_INPUT; 394 } else if(osorted && (T & COMPOSE_FST2_STRING)) { 395 find_type_ = FIND_OUTPUT; 396 } else if (isorted && osorted) { 397 find_type_ = FIND_BOTH; 398 } else if (isorted) { 399 find_type_ = FIND_INPUT; 400 } else if (osorted) { 401 find_type_ = FIND_OUTPUT; 402 } else { 403 LOG(FATAL) << "ComposeFst: 1st argument not output label sorted " 404 << "and 2nd argument is not input label sorted"; 405 } 406 } 407 408 // Finds/creates an Fst state given a StateTuple. Only creates a new 409 // state if StateTuple is not found in the state hash. 410 // 411 // The method exploits the following device: all pairs stored in the 412 // associative container state_tuple_table_ are of the form (tuple, 413 // id(tuple) + 1), i.e. state_tuple_table_[tuple] > 0 if tuple has 414 // been stored previously. For unassigned tuples, the call to 415 // state_tuple_table_[tuple] creates a new pair (tuple, 0). 
As a 416 // result, state_tuple_table_[tuple] == 0 iff tuple is new. 417 StateId FindState(const StateTuple& tuple) { 418 StateId &assoc_value = state_tuple_table_[tuple]; 419 if (assoc_value == 0) { // tuple wasn't present in lookup table: 420 // assign it a new ID. 421 state_tuples_.push_back(tuple); 422 assoc_value = state_tuples_.size(); 423 } 424 return assoc_value - 1; // NB: assoc_value = ID + 1 425 } 426 427 // Generates arc for composition state s from matched input Fst arcs. 428 void AddArc(StateId s, const A &arca, const A &arcb, int f, 429 bool find_input) { 430 A arc; 431 if (find_input) { 432 arc.ilabel = arcb.ilabel; 433 arc.olabel = arca.olabel; 434 arc.weight = Times(arcb.weight, arca.weight); 435 StateTuple tuple(arcb.nextstate, arca.nextstate, f); 436 arc.nextstate = FindState(tuple); 437 } else { 438 arc.ilabel = arca.ilabel; 439 arc.olabel = arcb.olabel; 440 arc.weight = Times(arca.weight, arcb.weight); 441 StateTuple tuple(arca.nextstate, arcb.nextstate, f); 442 arc.nextstate = FindState(tuple); 443 } 444 CacheImpl<A>::AddArc(s, arc); 445 } 446 447 // Arranges it so that the first arg to OrderedExpand is the Fst 448 // that will be passed to FindLabel. 449 void Expand(StateId s) { 450 StateTuple &tuple = state_tuples_[s]; 451 StateId s1 = tuple.state_id1; 452 StateId s2 = tuple.state_id2; 453 int f = tuple.filt; 454 if (find_type_ == FIND_INPUT) 455 OrderedExpand(s, ComposeFstImplBase<A>::fst2_, s2, 456 ComposeFstImplBase<A>::fst1_, s1, f, true); 457 else 458 OrderedExpand(s, ComposeFstImplBase<A>::fst1_, s1, 459 ComposeFstImplBase<A>::fst2_, s2, f, false); 460 } 461 462 // Access to flags encoding compose options/optimizations etc. (for 463 // debugging). 464 virtual uint64 ComposeFlags() const { return T; } 465 466 private: 467 // This does that actual matching of labels in the composition. The 468 // arguments are ordered so FindLabel is called with state SA of 469 // FSTA for each arc leaving state SB of FSTB. 
The FIND_INPUT arg 470 // determines whether the input or output label of arcs at SB is 471 // the one to match on. 472 void OrderedExpand(StateId s, const Fst<A> *fsta, StateId sa, 473 const Fst<A> *fstb, StateId sb, int f, bool find_input) { 474 475 size_t numarcsa = fsta->NumArcs(sa); 476 size_t numepsa = find_input ? fsta->NumInputEpsilons(sa) : 477 fsta->NumOutputEpsilons(sa); 478 bool finala = fsta->Final(sa) != Weight::Zero(); 479 ArcIterator< Fst<A> > aitera(*fsta, sa); 480 // First handle special epsilons and sigmas on FSTA 481 for (; !aitera.Done(); aitera.Next()) { 482 const A &arca = aitera.Value(); 483 Label match_labela = find_input ? arca.ilabel : arca.olabel; 484 if (match_labela > 0) { 485 break; 486 } 487 if ((T & COMPOSE_SIGMA) != 0 && match_labela == kSigmaLabel) { 488 // Found a sigma? Match it against all (non-special) symbols 489 // on side b. 490 for (ArcIterator< Fst<A> > aiterb(*fstb, sb); 491 !aiterb.Done(); 492 aiterb.Next()) { 493 const A &arcb = aiterb.Value(); 494 Label labelb = find_input ? arcb.olabel : arcb.ilabel; 495 if (labelb <= 0) continue; 496 AddArc(s, arca, arcb, 0, find_input); 497 } 498 } else if (f == 0 && match_labela == 0) { 499 A earcb(0, 0, Weight::One(), sb); 500 AddArc(s, arca, earcb, 0, find_input); // move forward on epsilon 501 } 502 } 503 // Next handle non-epsilon matches, rho labels, and epsilons on FSTB 504 for (ArcIterator< Fst<A> > aiterb(*fstb, sb); 505 !aiterb.Done(); 506 aiterb.Next()) { 507 const A &arcb = aiterb.Value(); 508 Label match_labelb = find_input ? arcb.olabel : arcb.ilabel; 509 if (match_labelb) { // Consider non-epsilon match 510 if (FindLabel(&aitera, numarcsa, match_labelb, find_input)) { 511 for (; !aitera.Done(); aitera.Next()) { 512 const A &arca = aitera.Value(); 513 Label match_labela = find_input ? 
arca.ilabel : arca.olabel; 514 if (match_labela != match_labelb) 515 break; 516 AddArc(s, arca, arcb, 0, find_input); // move forward on match 517 } 518 } else if ((T & COMPOSE_SPECIAL_SYMBOLS) != 0) { 519 // If there is no transition labelled 'match_labelb' in 520 // fsta, try matching 'match_labelb' against special symbols 521 // (Phi, Rho,...). 522 for (aitera.Reset(); !aitera.Done(); aitera.Next()) { 523 A arca = aitera.Value(); 524 Label labela = find_input ? arca.ilabel : arca.olabel; 525 if (labela >= 0) { 526 break; 527 } else if (((T & COMPOSE_PHI) != 0) && (labela == kPhiLabel)) { 528 // Case 1: if a failure transition exists, follow its 529 // transitive closure until a) a transition labelled 530 // 'match_labelb' is found, or b) the initial state of 531 // fsta is reached. 532 533 StateId sf = sa; // Start of current failure transition. 534 while (labela == kPhiLabel && sf != arca.nextstate) { 535 sf = arca.nextstate; 536 537 size_t numarcsf = fsta->NumArcs(sf); 538 ArcIterator< Fst<A> > aiterf(*fsta, sf); 539 if (FindLabel(&aiterf, numarcsf, match_labelb, find_input)) { 540 // Sub-case 1a: there exists a transition starting 541 // in sf and consuming symbol 'match_labelb'. 542 AddArc(s, aiterf.Value(), arcb, 0, find_input); 543 break; 544 } else { 545 // No transition labelled 'match_labelb' found: try 546 // next failure transition (starting at 'sf'). 547 for (aiterf.Reset(); !aiterf.Done(); aiterf.Next()) { 548 arca = aiterf.Value(); 549 labela = find_input ? arca.ilabel : arca.olabel; 550 if (labela >= kPhiLabel) break; 551 } 552 } 553 } 554 if (labela == kPhiLabel && sf == arca.nextstate) { 555 // Sub-case 1b: failure transitions lead to start 556 // state without finding a matching 557 // transition. Therefore, we generate a loop in start 558 // state of fsta. 
559 A loop(match_labelb, match_labelb, Weight::One(), sf); 560 AddArc(s, loop, arcb, 0, find_input); 561 } 562 } else if (((T & COMPOSE_RHO) != 0) && (labela == kRhoLabel)) { 563 // Case 2: 'match_labelb' can be matched against a 564 // "rest" (rho) label in fsta. 565 if (find_input) { 566 arca.ilabel = match_labelb; 567 if (arca.olabel == kRhoLabel) 568 arca.olabel = match_labelb; 569 } else { 570 arca.olabel = match_labelb; 571 if (arca.ilabel == kRhoLabel) 572 arca.ilabel = match_labelb; 573 } 574 AddArc(s, arca, arcb, 0, find_input); // move fwd on match 575 } 576 } 577 } 578 } else if (numepsa != numarcsa || finala) { // Handle FSTB epsilon 579 A earca(0, 0, Weight::One(), sa); 580 AddArc(s, earca, arcb, numepsa > 0, find_input); // move on epsilon 581 } 582 } 583 this->SetArcs(s); 584 } 585 586 587 // Finds matches to MATCH_LABEL in arcs given by AITER 588 // using FIND_INPUT to determine whether to look on input or output. 589 bool FindLabel(ArcIterator< Fst<A> > *aiter, size_t numarcs, 590 Label match_label, bool find_input) { 591 // binary search for match 592 size_t low = 0; 593 size_t high = numarcs; 594 while (low < high) { 595 size_t mid = (low + high) / 2; 596 aiter->Seek(mid); 597 Label label = find_input ? 598 aiter->Value().ilabel : aiter->Value().olabel; 599 if (label > match_label) { 600 high = mid; 601 } else if (label < match_label) { 602 low = mid + 1; 603 } else { 604 // find first matching label (when non-determinism) 605 for (size_t i = mid; i > low; --i) { 606 aiter->Seek(i - 1); 607 label = find_input ? 
aiter->Value().ilabel : aiter->Value().olabel; 608 if (label != match_label) { 609 aiter->Seek(i); 610 return true; 611 } 612 } 613 return true; 614 } 615 } 616 return false; 617 } 618 619 StateId ComputeStart() { 620 StateId s1 = ComposeFstImplBase<A>::fst1_->Start(); 621 StateId s2 = ComposeFstImplBase<A>::fst2_->Start(); 622 if (s1 == kNoStateId || s2 == kNoStateId) 623 return kNoStateId; 624 StateTuple tuple(s1, s2, 0); 625 return FindState(tuple); 626 } 627 628 Weight ComputeFinal(StateId s) { 629 StateTuple &tuple = state_tuples_[s]; 630 Weight final = Times(ComposeFstImplBase<A>::fst1_->Final(tuple.state_id1), 631 ComposeFstImplBase<A>::fst2_->Final(tuple.state_id2)); 632 return final; 633 } 634 635 636 FindType find_type_; // find label on which side? 637 638 // Maps from StateId to StateTuple. 639 vector<StateTuple> state_tuples_; 640 641 // Maps from StateTuple to StateId. 642 StateTupleTable state_tuple_table_; 643 644 DISALLOW_EVIL_CONSTRUCTORS(ComposeFstImpl); 645 }; 646 647 648 // Computes the composition of two transducers. This version is a 649 // delayed Fst. If FST1 transduces string x to y with weight a and FST2 650 // transduces y to z with weight b, then their composition transduces 651 // string x to z with weight Times(x, z). 652 // 653 // The output labels of the first transducer or the input labels of 654 // the second transducer must be sorted. The weights need to form a 655 // commutative semiring (valid for TropicalWeight and LogWeight). 656 // 657 // Complexity: 658 // Assuming the first FST is unsorted and the second is sorted: 659 // - Time: O(v1 v2 d1 (log d2 + m2)), 660 // - Space: O(v1 v2) 661 // where vi = # of states visited, di = maximum out-degree, and mi the 662 // maximum multiplicity of the states visited for the ith 663 // FST. Constant time and space to visit an input state or arc is 664 // assumed and exclusive of caching. 
665 // 666 // Caveats: 667 // - ComposeFst does not trim its output (since it is a delayed operation). 668 // - The efficiency of composition can be strongly affected by several factors: 669 // - the choice of which tnansducer is sorted - prefer sorting the FST 670 // that has the greater average out-degree. 671 // - the amount of non-determinism 672 // - the presence and location of epsilon transitions - avoid epsilon 673 // transitions on the output side of the first transducer or 674 // the input side of the second transducer or prefer placing 675 // them later in a path since they delay matching and can 676 // introduce non-coaccessible states and transitions. 677 template <class A> 678 class ComposeFst : public Fst<A> { 679 public: 680 friend class ArcIterator< ComposeFst<A> >; 681 friend class CacheStateIterator< ComposeFst<A> >; 682 friend class CacheArcIterator< ComposeFst<A> >; 683 684 typedef A Arc; 685 typedef typename A::Weight Weight; 686 typedef typename A::StateId StateId; 687 typedef CacheState<A> State; 688 689 ComposeFst(const Fst<A> &fst1, const Fst<A> &fst2) 690 : impl_(Init(fst1, fst2, ComposeFstOptions<>())) { } 691 692 template <uint64 T> 693 ComposeFst(const Fst<A> &fst1, 694 const Fst<A> &fst2, 695 const ComposeFstOptions<T> &opts) 696 : impl_(Init(fst1, fst2, opts)) { } 697 698 ComposeFst(const ComposeFst<A> &fst) : Fst<A>(fst), impl_(fst.impl_) { 699 impl_->IncrRefCount(); 700 } 701 702 virtual ~ComposeFst() { if (!impl_->DecrRefCount()) delete impl_; } 703 704 virtual StateId Start() const { return impl_->Start(); } 705 706 virtual Weight Final(StateId s) const { return impl_->Final(s); } 707 708 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } 709 710 virtual size_t NumInputEpsilons(StateId s) const { 711 return impl_->NumInputEpsilons(s); 712 } 713 714 virtual size_t NumOutputEpsilons(StateId s) const { 715 return impl_->NumOutputEpsilons(s); 716 } 717 718 virtual uint64 Properties(uint64 mask, bool test) const { 
719 if (test) { 720 uint64 known, test = TestProperties(*this, mask, &known); 721 impl_->SetProperties(test, known); 722 return test & mask; 723 } else { 724 return impl_->Properties(mask); 725 } 726 } 727 728 virtual const string& Type() const { return impl_->Type(); } 729 730 virtual ComposeFst<A> *Copy() const { 731 return new ComposeFst<A>(*this); 732 } 733 734 virtual const SymbolTable* InputSymbols() const { 735 return impl_->InputSymbols(); 736 } 737 738 virtual const SymbolTable* OutputSymbols() const { 739 return impl_->OutputSymbols(); 740 } 741 742 virtual inline void InitStateIterator(StateIteratorData<A> *data) const; 743 744 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 745 impl_->InitArcIterator(s, data); 746 } 747 748 // Access to flags encoding compose options/optimizations etc. (for 749 // debugging). 750 uint64 ComposeFlags() const { return impl_->ComposeFlags(); } 751 752 protected: 753 ComposeFstImplBase<A> *Impl() { return impl_; } 754 755 private: 756 ComposeFstImplBase<A> *impl_; 757 758 // Auxiliary method encapsulating the creation of a ComposeFst 759 // implementation that is appropriate for the properties of fst1 and 760 // fst2. 761 template <uint64 T> 762 static ComposeFstImplBase<A> *Init( 763 const Fst<A> &fst1, 764 const Fst<A> &fst2, 765 const ComposeFstOptions<T> &opts) { 766 767 // Filter for sort properties (forces a property check). 768 uint64 sort_props_mask = kILabelSorted | kOLabelSorted; 769 // Filter for optimization-related properties (does not force a 770 // property-check). 
771 uint64 opt_props_mask = 772 kString | kIDeterministic | kODeterministic | kNoIEpsilons | 773 kNoOEpsilons; 774 775 uint64 props1 = fst1.Properties(sort_props_mask, true); 776 uint64 props2 = fst2.Properties(sort_props_mask, true); 777 778 props1 |= fst1.Properties(opt_props_mask, false); 779 props2 |= fst2.Properties(opt_props_mask, false); 780 781 if (!(Weight::Properties() & kCommutative)) { 782 props1 |= fst1.Properties(kUnweighted, true); 783 props2 |= fst2.Properties(kUnweighted, true); 784 if (!(props1 & kUnweighted) && !(props2 & kUnweighted)) 785 LOG(FATAL) << "ComposeFst: Weight needs to be a commutative semiring: " 786 << Weight::Type(); 787 } 788 789 // Case 1: flag COMPOSE_GENERIC disables optimizations. 790 if (T & COMPOSE_GENERIC) { 791 return new ComposeFstImpl<A, T>(fst1, fst2, opts); 792 } 793 794 const uint64 kStringDetOptProps = 795 kIDeterministic | kILabelSorted | kNoIEpsilons; 796 const uint64 kDetStringOptProps = 797 kODeterministic | kOLabelSorted | kNoOEpsilons; 798 799 // Case 2: fst1 is a string, fst2 is deterministic and epsilon-free. 800 if ((props1 & kString) && 801 !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) && 802 ((props2 & kStringDetOptProps) == kStringDetOptProps)) { 803 return new ComposeFstImpl<A, T | COMPOSE_FST1_STRING | COMPOSE_FST2_DET>( 804 fst1, fst2, opts); 805 } 806 // Case 3: fst1 is deterministic and epsilon-free, fst2 is string. 807 if ((props2 & kString) && 808 !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) && 809 ((props1 & kDetStringOptProps) == kDetStringOptProps)) { 810 return new ComposeFstImpl<A, T | COMPOSE_FST2_STRING | COMPOSE_FST1_DET>( 811 fst1, fst2, opts); 812 } 813 814 // Default case: no optimizations. 815 return new ComposeFstImpl<A, T>(fst1, fst2, opts); 816 } 817 818 void operator=(const ComposeFst<A> &fst); // disallow 819 }; 820 821 822 // Specialization for ComposeFst. 
823 template<class A> 824 class StateIterator< ComposeFst<A> > 825 : public CacheStateIterator< ComposeFst<A> > { 826 public: 827 explicit StateIterator(const ComposeFst<A> &fst) 828 : CacheStateIterator< ComposeFst<A> >(fst) {} 829 }; 830 831 832 // Specialization for ComposeFst. 833 template <class A> 834 class ArcIterator< ComposeFst<A> > 835 : public CacheArcIterator< ComposeFst<A> > { 836 public: 837 typedef typename A::StateId StateId; 838 839 ArcIterator(const ComposeFst<A> &fst, StateId s) 840 : CacheArcIterator< ComposeFst<A> >(fst, s) { 841 if (!fst.impl_->HasArcs(s)) 842 fst.impl_->Expand(s); 843 } 844 845 private: 846 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); 847 }; 848 849 template <class A> inline 850 void ComposeFst<A>::InitStateIterator(StateIteratorData<A> *data) const { 851 data->base = new StateIterator< ComposeFst<A> >(*this); 852 } 853 854 // Useful alias when using StdArc. 855 typedef ComposeFst<StdArc> StdComposeFst; 856 857 858 struct ComposeOptions { 859 bool connect; // Connect output 860 861 ComposeOptions(bool c) : connect(c) {} 862 ComposeOptions() : connect(true) { } 863 }; 864 865 866 // Computes the composition of two transducers. This version writes 867 // the composed FST into a MurableFst. If FST1 transduces string x to 868 // y with weight a and FST2 transduces y to z with weight b, then 869 // their composition transduces string x to z with weight 870 // Times(x, z). 871 // 872 // The output labels of the first transducer or the input labels of 873 // the second transducer must be sorted. The weights need to form a 874 // commutative semiring (valid for TropicalWeight and LogWeight). 875 // 876 // Complexity: 877 // Assuming the first FST is unsorted and the second is sorted: 878 // - Time: O(V1 V2 D1 (log D2 + M2)), 879 // - Space: O(V1 V2 D1 M2) 880 // where Vi = # of states, Di = maximum out-degree, and Mi is 881 // the maximum multiplicity for the ith FST. 882 // 883 // Caveats: 884 // - Compose trims its output. 
885 // - The efficiency of composition can be strongly affected by several factors: 886 // - the choice of which tnansducer is sorted - prefer sorting the FST 887 // that has the greater average out-degree. 888 // - the amount of non-determinism 889 // - the presence and location of epsilon transitions - avoid epsilon 890 // transitions on the output side of the first transducer or 891 // the input side of the second transducer or prefer placing 892 // them later in a path since they delay matching and can 893 // introduce non-coaccessible states and transitions. 894 template<class Arc> 895 void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2, 896 MutableFst<Arc> *ofst, 897 const ComposeOptions &opts = ComposeOptions()) { 898 ComposeFstOptions<> nopts; 899 nopts.gc_limit = 0; // Cache only the last state for fastest copy. 900 *ofst = ComposeFst<Arc>(ifst1, ifst2, nopts); 901 if (opts.connect) 902 Connect(ofst); 903 } 904 905 } // namespace fst 906 907 #endif // FST_LIB_COMPOSE_H__ 908