// cache.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley (at) google.com (Michael Riley)
//
// \file
// An Fst implementation that caches FST elements of a delayed
// computation.

#ifndef FST_LIB_CACHE_H__
#define FST_LIB_CACHE_H__

#include <vector>
using std::vector;
#include <list>

#include <fst/vector-fst.h>


// Command-line flags supplying the default garbage-collection policy
// for caches constructed without explicit CacheOptions.
DECLARE_bool(fst_default_cache_gc);
DECLARE_int64(fst_default_cache_gc_limit);

namespace fst {

// Garbage-collection options for a cache: whether GC is enabled and
// how many bytes may be cached before a collection is triggered.
struct CacheOptions {
  bool gc;          // enable GC
  size_t gc_limit;  // # of bytes allowed before GC

  CacheOptions(bool g, size_t l) : gc(g), gc_limit(l) {}
  // Default-constructs from the command-line flag values above.
  CacheOptions()
      : gc(FLAGS_fst_default_cache_gc),
        gc_limit(FLAGS_fst_default_cache_gc_limit) {}
};

// A CacheStateAllocator allocates and frees CacheStates
// template <class S>
// struct CacheStateAllocator {
//   S *Allocate(StateId s);
//   void Free(S *state, StateId s);
// };
//

// A simple allocator class, can be overridden as needed,
// maintains a single entry cache.
// Default allocator: new/delete with a one-element free list.  The
// most recently freed state is held back and handed out (after a
// Reset()) on the next Allocate() call, avoiding an allocation.
template <class S>
struct DefaultCacheStateAllocator {
  typedef typename S::Arc::StateId StateId;

  DefaultCacheStateAllocator() : spare_(NULL) {}

  ~DefaultCacheStateAllocator() {
    delete spare_;
  }

  // Returns a state for id 's'; recycles the cached entry when one
  // is available, otherwise heap-allocates a fresh state.
  S *Allocate(StateId s) {
    if (spare_ == NULL) {
      return new S();
    }
    S *recycled = spare_;
    spare_ = NULL;
    recycled->Reset();
    return recycled;
  }

  // Releases 'state'; it becomes the single cached entry, and any
  // previously cached entry is destroyed.
  void Free(S *state, StateId s) {
    delete spare_;
    spare_ = state;
  }

 private:
  S *spare_;  // single-entry cache of the most recently freed state
};

// VectorState but additionally has a flags data member (see
// CacheState below). This class is used to cache FST elements with
// the flags used to indicate what has been cached. Use HasStart()
// HasFinal(), and HasArcs() to determine if cached and SetStart(),
// SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note
// you must set the final weight even if the state is non-final to
// mark it as cached. If the 'gc' option is 'false', cached items have
// the extent of the FST - minimizing computation. If the 'gc' option
// is 'true', garbage collection of states (not in use in an arc
// iterator and not 'protected') is performed, in a rough
// approximation of LRU order, when 'gc_limit' bytes is reached -
// controlling memory use. When 'gc_limit' is 0, special optimizations
// apply - minimizing memory use.
// Shared implementation for FSTs computed lazily: states/arcs are
// produced on demand, stored in the underlying VectorFstBaseImpl
// vector, and optionally garbage collected in rough LRU order.
template <class S, class C = DefaultCacheStateAllocator<S> >
class CacheBaseImpl : public VectorFstBaseImpl<S> {
 public:
  typedef S State;
  typedef C Allocator;
  typedef typename State::Arc Arc;
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;

  using FstImpl<Arc>::Type;
  using FstImpl<Arc>::Properties;
  using FstImpl<Arc>::SetProperties;
  using VectorFstBaseImpl<State>::NumStates;
  using VectorFstBaseImpl<State>::Start;
  using VectorFstBaseImpl<State>::AddState;
  using VectorFstBaseImpl<State>::SetState;
  using VectorFstBaseImpl<State>::ReserveStates;

  // Constructs with GC policy taken from the command-line flags.
  // Takes ownership of 'allocator' (deleted in the destructor); if 0,
  // a default allocator is created.  A non-zero flag limit below
  // kMinCacheLimit is raised to kMinCacheLimit; 0 selects the
  // special minimal-memory mode (see ExtendState).
  explicit CacheBaseImpl(C *allocator = 0)
      : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0),
        cache_first_state_id_(kNoStateId), cache_first_state_(0),
        cache_gc_(FLAGS_fst_default_cache_gc), cache_size_(0),
        cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit ||
                     FLAGS_fst_default_cache_gc_limit == 0 ?
                     FLAGS_fst_default_cache_gc_limit : kMinCacheLimit),
        protect_(false) {
    allocator_ = allocator ? allocator : new C();
  }

  // As above but with an explicit GC policy from 'opts'.
  explicit CacheBaseImpl(const CacheOptions &opts, C *allocator = 0)
      : cache_start_(false), nknown_states_(0),
        min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
        cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0),
        cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ?
                     opts.gc_limit : kMinCacheLimit),
        protect_(false) {
    allocator_ = allocator ? allocator : new C();
  }

  // Preserve gc parameters. If preserve_cache true, also preserves
  // cache data.
  CacheBaseImpl(const CacheBaseImpl<S, C> &impl, bool preserve_cache = false)
      : VectorFstBaseImpl<S>(), cache_start_(false), nknown_states_(0),
        min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
        cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0),
        cache_limit_(impl.cache_limit_),
        protect_(impl.protect_) {
    allocator_ = new C();
    if (preserve_cache) {
      cache_start_ = impl.cache_start_;
      nknown_states_ = impl.nknown_states_;
      expanded_states_ = impl.expanded_states_;
      min_unexpanded_state_id_ = impl.min_unexpanded_state_id_;
      if (impl.cache_first_state_id_ != kNoStateId) {
        // Deep-copy the fast-path first cached state.
        cache_first_state_id_ = impl.cache_first_state_id_;
        cache_first_state_ = allocator_->Allocate(cache_first_state_id_);
        *cache_first_state_ = *impl.cache_first_state_;
      }
      cache_states_ = impl.cache_states_;
      cache_size_ = impl.cache_size_;
      ReserveStates(impl.NumStates());
      for (StateId s = 0; s < impl.NumStates(); ++s) {
        // Use the base-class accessor to bypass the first-state
        // special-casing of this class's GetState().
        const S *state =
            static_cast<const VectorFstBaseImpl<S> &>(impl).GetState(s);
        if (state) {
          S *copied_state = allocator_->Allocate(s);
          *copied_state = *state;
          AddState(copied_state);
        } else {
          AddState(0);  // preserve holes for uncached states
        }
      }
      VectorFstBaseImpl<S>::SetStart(impl.Start());
    }
  }

  ~CacheBaseImpl() {
    allocator_->Free(cache_first_state_, cache_first_state_id_);
    delete allocator_;
  }

  // Gets a state from its ID; state must exist.
  const S *GetState(StateId s) const {
    if (s == cache_first_state_id_)
      return cache_first_state_;
    else
      return VectorFstBaseImpl<S>::GetState(s);
  }

  // Gets a state from its ID; state must exist.
  S *GetState(StateId s) {
    if (s == cache_first_state_id_)
      return cache_first_state_;
    else
      return VectorFstBaseImpl<S>::GetState(s);
  }

  // Gets a state from its ID; return 0 if it doesn't exist.
  const S *CheckState(StateId s) const {
    if (s == cache_first_state_id_)
      return cache_first_state_;
    else if (s < NumStates())
      return VectorFstBaseImpl<S>::GetState(s);
    else
      return 0;
  }

  // Gets a state from its ID; add it if necessary.
  S *ExtendState(StateId s);

  // Caches the start state; widens the known-state count if needed.
  void SetStart(StateId s) {
    VectorFstBaseImpl<S>::SetStart(s);
    cache_start_ = true;
    if (s >= nknown_states_)
      nknown_states_ = s + 1;
  }

  // Caches the final weight of state s.  Must be called even for
  // non-final states (with Weight::Zero()) to mark them as cached.
  void SetFinal(StateId s, Weight w) {
    S *state = ExtendState(s);
    state->final = w;
    state->flags |= kCacheFinal | kCacheRecent | kCacheModified;
  }

  // AddArc adds a single arc to state s and does incremental cache
  // book-keeping. For efficiency, prefer PushArc and SetArcs below
  // when possible.
  void AddArc(StateId s, const Arc &arc) {
    S *state = ExtendState(s);
    state->arcs.push_back(arc);
    if (arc.ilabel == 0) {
      ++state->niepsilons;
    }
    if (arc.olabel == 0) {
      ++state->noepsilons;
    }
    const Arc *parc = state->arcs.empty() ? 0 : &(state->arcs.back());
    SetProperties(AddArcProperties(Properties(), s, arc, parc));
    state->flags |= kCacheModified;
    // Protected states and the first-state fast path are excluded
    // from the byte accounting (they are never GC'd).
    if (cache_gc_ && s != cache_first_state_id_ &&
        !(state->flags & kCacheProtect)) {
      cache_size_ += sizeof(Arc);
      if (cache_size_ > cache_limit_)
        GC(s, false);
    }
  }

  // Adds a single arc to state s but delays cache book-keeping.
  // SetArcs must be called when all PushArc calls at a state are
  // complete. Do not mix with calls to AddArc.
  void PushArc(StateId s, const Arc &arc) {
    S *state = ExtendState(s);
    state->arcs.push_back(arc);
  }

  // Marks arcs of state s as cached and does cache book-keeping after all
  // calls to PushArc have been completed. Do not mix with calls to AddArc.
  void SetArcs(StateId s) {
    S *state = ExtendState(s);
    vector<Arc> &arcs = state->arcs;
    // Recompute epsilon counts and known-state bound from scratch.
    state->niepsilons = state->noepsilons = 0;
    for (size_t a = 0; a < arcs.size(); ++a) {
      const Arc &arc = arcs[a];
      if (arc.nextstate >= nknown_states_)
        nknown_states_ = arc.nextstate + 1;
      if (arc.ilabel == 0)
        ++state->niepsilons;
      if (arc.olabel == 0)
        ++state->noepsilons;
    }
    ExpandedState(s);
    state->flags |= kCacheArcs | kCacheRecent | kCacheModified;
    if (cache_gc_ && s != cache_first_state_id_ &&
        !(state->flags & kCacheProtect)) {
      cache_size_ += arcs.capacity() * sizeof(Arc);
      if (cache_size_ > cache_limit_)
        GC(s, false);
    }
  };

  void ReserveArcs(StateId s, size_t n) {
    S *state = ExtendState(s);
    state->arcs.reserve(n);
  }

  // Deletes the last n arcs of state s, updating epsilon counts and
  // the cache byte accounting.
  void DeleteArcs(StateId s, size_t n) {
    S *state = ExtendState(s);
    const vector<Arc> &arcs = state->arcs;
    for (size_t i = 0; i < n; ++i) {
      size_t j = arcs.size() - i - 1;
      if (arcs[j].ilabel == 0)
        --state->niepsilons;
      if (arcs[j].olabel == 0)
        --state->noepsilons;
    }

    state->arcs.resize(arcs.size() - n);
    SetProperties(DeleteArcsProperties(Properties()));
    state->flags |= kCacheModified;
    if (cache_gc_ && s != cache_first_state_id_ &&
        !(state->flags & kCacheProtect)) {
      cache_size_ -= n * sizeof(Arc);
    }
  }

  // Deletes all arcs of state s.
  void DeleteArcs(StateId s) {
    S *state = ExtendState(s);
    size_t n = state->arcs.size();
    state->niepsilons = 0;
    state->noepsilons = 0;
    state->arcs.clear();
    SetProperties(DeleteArcsProperties(Properties()));
    state->flags |= kCacheModified;
    if (cache_gc_ && s != cache_first_state_id_ &&
        !(state->flags & kCacheProtect)) {
      cache_size_ -= n * sizeof(Arc);
    }
  }

  // Deletes the listed states and renumbers the survivors.
  // NOTE(review): cache_size_ is not reduced here even though the
  // deleted states' storage is released — confirm callers reset or
  // tolerate the stale accounting.
  void DeleteStates(const vector<StateId> &dstates) {
    size_t old_num_states = NumStates();
    // Build old-id -> new-id map; kNoStateId marks deletion.
    vector<StateId> newid(old_num_states, 0);
    for (size_t i = 0; i < dstates.size(); ++i)
      newid[dstates[i]] = kNoStateId;
    StateId nstates = 0;
    for (StateId s = 0; s < old_num_states; ++s) {
      if (newid[s] != kNoStateId) {
        newid[s] = nstates;
        ++nstates;
      }
    }
    // just for states_.resize(), does unnecessary walk.
    VectorFstBaseImpl<S>::DeleteStates(dstates);
    SetProperties(DeleteStatesProperties(Properties()));
    // Update list of cached states.
    typename list<StateId>::iterator siter = cache_states_.begin();
    while (siter != cache_states_.end()) {
      if (newid[*siter] != kNoStateId) {
        *siter = newid[*siter];
        ++siter;
      } else {
        cache_states_.erase(siter++);
      }
    }
  }

  // Deletes all states, resetting the cache to its initial state.
  void DeleteStates() {
    cache_states_.clear();
    allocator_->Free(cache_first_state_, cache_first_state_id_);
    for (int s = 0; s < NumStates(); ++s) {
      allocator_->Free(VectorFstBaseImpl<S>::GetState(s), s);
      SetState(s, 0);
    }
    nknown_states_ = 0;
    min_unexpanded_state_id_ = 0;
    cache_first_state_id_ = kNoStateId;
    cache_first_state_ = 0;
    cache_size_ = 0;
    cache_start_ = false;
    VectorFstBaseImpl<State>::DeleteStates();
    SetProperties(DeleteAllStatesProperties(Properties(),
                                            kExpanded | kMutable));
  }

  // Is the start state cached?  An error property forces 'cached'
  // (presumably so iteration over an errored FST terminates).
  bool HasStart() const {
    if (!cache_start_ && Properties(kError))
      cache_start_ = true;
    return cache_start_;
  }

  // Is the final weight of state s cached?
  bool HasFinal(StateId s) const {
    const S *state = CheckState(s);
    if (state && state->flags & kCacheFinal) {
      state->flags |= kCacheRecent;  // touch for LRU approximation
      return true;
    } else {
      return false;
    }
  }

  // Are arcs of state s cached?
  bool HasArcs(StateId s) const {
    const S *state = CheckState(s);
    if (state && state->flags & kCacheArcs) {
      state->flags |= kCacheRecent;  // touch for LRU approximation
      return true;
    } else {
      return false;
    }
  }

  Weight Final(StateId s) const {
    const S *state = GetState(s);
    return state->final;
  }

  size_t NumArcs(StateId s) const {
    const S *state = GetState(s);
    return state->arcs.size();
  }

  size_t NumInputEpsilons(StateId s) const {
    const S *state = GetState(s);
    return state->niepsilons;
  }

  size_t NumOutputEpsilons(StateId s) const {
    const S *state = GetState(s);
    return state->noepsilons;
  }

  // Provides information needed for generic arc iterator.
  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
    const S *state = GetState(s);
    data->base = 0;
    data->narcs = state->arcs.size();
    data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0;
    // Bump the ref count so GC will not free this state while the
    // iterator is live.
    data->ref_count = &(state->ref_count);
    ++(*data->ref_count);
  }

  // Number of known states.
  StateId NumKnownStates() const { return nknown_states_; }

  // Update number of known states taking in account the existence of state s.
  void UpdateNumKnownStates(StateId s) {
    if (s >= nknown_states_)
      nknown_states_ = s + 1;
  }

  // Find the mininum never-expanded state Id
  StateId MinUnexpandedState() const {
    while (min_unexpanded_state_id_ < expanded_states_.size() &&
           expanded_states_[min_unexpanded_state_id_])
      ++min_unexpanded_state_id_;
    return min_unexpanded_state_id_;
  }

  // Removes from cache_states_ and uncaches (not referenced-counted
  // or protected) states that have not been accessed since the last
  // GC until at most cache_fraction * cache_limit_ bytes are cached.
  // If that fails to free enough, recurs uncaching recently visited
  // states as well. If still unable to free enough memory, then
  // widens cache_limit_ to fulfill condition.
  void GC(StateId current, bool free_recent, float cache_fraction = 0.666);

  // Sets/clears GC protection: if true, new states are protected
  // from garbage collection.
  void GCProtect(bool on) { protect_ = on; }

  // Marks state s as expanded (arcs fully generated).
  void ExpandedState(StateId s) {
    if (s < min_unexpanded_state_id_)
      return;
    while (expanded_states_.size() <= s)
      expanded_states_.push_back(false);
    expanded_states_[s] = true;
  }

  C *GetAllocator() const {
    return allocator_;
  }

  // Caching on/off switch, limit and size accessors.
  bool GetCacheGc() const { return cache_gc_; }
  size_t GetCacheLimit() const { return cache_limit_; }
  size_t GetCacheSize() const { return cache_size_; }

 private:
  static const size_t kMinCacheLimit = 8096;  // Minimum (non-zero) cache limit

  static const uint32 kCacheFinal = 0x0001;   // Final weight has been cached
  static const uint32 kCacheArcs = 0x0002;    // Arcs have been cached
  static const uint32 kCacheRecent = 0x0004;  // Mark as visited since GC
  static const uint32 kCacheProtect = 0x0008; // Mark state as GC protected

 public:
  static const uint32 kCacheModified = 0x0010;  // Mark state as modified
  static const uint32 kCacheFlags = kCacheFinal | kCacheArcs | kCacheRecent
      | kCacheProtect | kCacheModified;

 private:
  C *allocator_;                   // used to allocate new states
  mutable bool cache_start_;       // Is the start state cached?
  StateId nknown_states_;          // # of known states
  vector<bool> expanded_states_;   // states that have been expanded
  mutable StateId min_unexpanded_state_id_;  // minimum never-expanded state Id
  StateId cache_first_state_id_;   // First cached state id
  S *cache_first_state_;           // First cached state
  list<StateId> cache_states_;     // list of currently cached states
  bool cache_gc_;                  // enable GC
  size_t cache_size_;              // # of bytes cached
  size_t cache_limit_;             // # of bytes allowed before GC
  bool protect_;                   // Protect new states from GC

  void operator=(const CacheBaseImpl<S, C> &impl);  // disallow
};

// Gets a state from its ID; add it if necessary.
template <class S, class C>
S *CacheBaseImpl<S, C>::ExtendState(typename S::Arc::StateId s) {
  // If 'protect_' true and a new state, protects from garbage collection.
  if (s == cache_first_state_id_) {
    return cache_first_state_;                   // Return 1st cached state
  } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
    // Zero-limit mode: keep only a single cached state outside the
    // state vector.
    cache_first_state_id_ = s;                   // Remember 1st cached state
    cache_first_state_ = allocator_->Allocate(s);
    if (protect_) cache_first_state_->flags |= kCacheProtect;
    return cache_first_state_;
  } else if (cache_first_state_id_ != kNoStateId &&
             cache_first_state_->ref_count == 0 &&
             !(cache_first_state_->flags & kCacheProtect)) {
    // With Default allocator, the Free and Allocate will reuse the same S*.
    allocator_->Free(cache_first_state_, cache_first_state_id_);
    cache_first_state_id_ = s;
    cache_first_state_ = allocator_->Allocate(s);
    if (protect_) cache_first_state_->flags |= kCacheProtect;
    return cache_first_state_;                   // Return 1st cached state
  } else {
    while (NumStates() <= s)                     // Add state to main cache
      AddState(0);
    S *state = VectorFstBaseImpl<S>::GetState(s);
    if (!state) {
      state = allocator_->Allocate(s);
      if (protect_) state->flags |= kCacheProtect;
      SetState(s, state);
      if (cache_first_state_id_ != kNoStateId) {  // Forget 1st cached state
        // Migrate the fast-path state into the state vector and fall
        // back to the normal (limited) caching regime.
        while (NumStates() <= cache_first_state_id_)
          AddState(0);
        SetState(cache_first_state_id_, cache_first_state_);
        if (cache_gc_ && !(cache_first_state_->flags & kCacheProtect)) {
          cache_states_.push_back(cache_first_state_id_);
          cache_size_ += sizeof(S) +
                         cache_first_state_->arcs.capacity() * sizeof(Arc);
        }
        cache_limit_ = kMinCacheLimit;
        cache_first_state_id_ = kNoStateId;
        cache_first_state_ = 0;
      }
      if (cache_gc_ && !protect_) {
        cache_states_.push_back(s);
        cache_size_ += sizeof(S);
        if (cache_size_ > cache_limit_)
          GC(s, false);
      }
    }
    return state;
  }
}

// Removes from cache_states_ and uncaches (not referenced-counted or
// protected) states that have not been accessed since the last GC
// until at most cache_fraction * cache_limit_ bytes are cached. If
// that fails to free enough, recurs uncaching recently visited states
// as well. If still unable to free enough memory, then widens cache_limit_
// to fulfill condition.
558 template <class S, class C> 559 void CacheBaseImpl<S, C>::GC(typename S::Arc::StateId current, 560 bool free_recent, float cache_fraction) { 561 if (!cache_gc_) 562 return; 563 VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this 564 << "), free recently cached = " << free_recent 565 << ", cache size = " << cache_size_ 566 << ", cache frac = " << cache_fraction 567 << ", cache limit = " << cache_limit_ << "\n"; 568 typename list<StateId>::iterator siter = cache_states_.begin(); 569 570 size_t cache_target = cache_fraction * cache_limit_; 571 while (siter != cache_states_.end()) { 572 StateId s = *siter; 573 S* state = VectorFstBaseImpl<S>::GetState(s); 574 if (cache_size_ > cache_target && state->ref_count == 0 && 575 (free_recent || !(state->flags & kCacheRecent)) && s != current) { 576 cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc); 577 allocator_->Free(state, s); 578 SetState(s, 0); 579 cache_states_.erase(siter++); 580 } else { 581 state->flags &= ~kCacheRecent; 582 ++siter; 583 } 584 } 585 if (!free_recent && cache_size_ > cache_target) { // recurses on recent 586 GC(current, true); 587 } else if (cache_target > 0) { // widens cache limit 588 while (cache_size_ > cache_target) { 589 cache_limit_ *= 2; 590 cache_target *= 2; 591 } 592 } else if (cache_size_ > 0) { 593 FSTERROR() << "CacheImpl:GC: Unable to free all cached states"; 594 } 595 VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this 596 << "), free recently cached = " << free_recent 597 << ", cache size = " << cache_size_ 598 << ", cache frac = " << cache_fraction 599 << ", cache limit = " << cache_limit_ << "\n"; 600 } 601 602 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheFinal; 603 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheArcs; 604 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheRecent; 605 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheModified; 606 template 
<class S, class C> const size_t CacheBaseImpl<S, C>::kMinCacheLimit; 607 608 // Arcs implemented by an STL vector per state. Similar to VectorState 609 // but adds flags and ref count to keep track of what has been cached. 610 template <class A> 611 struct CacheState { 612 typedef A Arc; 613 typedef typename A::Weight Weight; 614 typedef typename A::StateId StateId; 615 616 CacheState() : final(Weight::Zero()), flags(0), ref_count(0) {} 617 618 void Reset() { 619 flags = 0; 620 ref_count = 0; 621 arcs.resize(0); 622 } 623 624 Weight final; // Final weight 625 vector<A> arcs; // Arcs represenation 626 size_t niepsilons; // # of input epsilons 627 size_t noepsilons; // # of output epsilons 628 mutable uint32 flags; 629 mutable int ref_count; 630 }; 631 632 // A CacheBaseImpl with a commonly used CacheState. 633 template <class A> 634 class CacheImpl : public CacheBaseImpl< CacheState<A> > { 635 public: 636 typedef CacheState<A> State; 637 638 CacheImpl() {} 639 640 explicit CacheImpl(const CacheOptions &opts) 641 : CacheBaseImpl< CacheState<A> >(opts) {} 642 643 CacheImpl(const CacheImpl<A> &impl, bool preserve_cache = false) 644 : CacheBaseImpl<State>(impl, preserve_cache) {} 645 646 private: 647 void operator=(const CacheImpl<State> &impl); // disallow 648 }; 649 650 651 // Use this to make a state iterator for a CacheBaseImpl-derived Fst, 652 // which must have type 'State' defined. Note this iterator only 653 // returns those states reachable from the initial state, so consider 654 // implementing a class-specific one. 
655 template <class F> 656 class CacheStateIterator : public StateIteratorBase<typename F::Arc> { 657 public: 658 typedef typename F::Arc Arc; 659 typedef typename Arc::StateId StateId; 660 typedef typename F::State State; 661 typedef CacheBaseImpl<State> Impl; 662 663 CacheStateIterator(const F &fst, Impl *impl) 664 : fst_(fst), impl_(impl), s_(0) { 665 fst_.Start(); // force start state 666 } 667 668 bool Done() const { 669 if (s_ < impl_->NumKnownStates()) 670 return false; 671 if (s_ < impl_->NumKnownStates()) 672 return false; 673 for (StateId u = impl_->MinUnexpandedState(); 674 u < impl_->NumKnownStates(); 675 u = impl_->MinUnexpandedState()) { 676 // force state expansion 677 ArcIterator<F> aiter(fst_, u); 678 aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache); 679 for (; !aiter.Done(); aiter.Next()) 680 impl_->UpdateNumKnownStates(aiter.Value().nextstate); 681 impl_->ExpandedState(u); 682 if (s_ < impl_->NumKnownStates()) 683 return false; 684 } 685 return true; 686 } 687 688 StateId Value() const { return s_; } 689 690 void Next() { ++s_; } 691 692 void Reset() { s_ = 0; } 693 694 private: 695 // This allows base class virtual access to non-virtual derived- 696 // class members of the same name. It makes the derived class more 697 // efficient to use but unsafe to further derive. 698 virtual bool Done_() const { return Done(); } 699 virtual StateId Value_() const { return Value(); } 700 virtual void Next_() { Next(); } 701 virtual void Reset_() { Reset(); } 702 703 const F &fst_; 704 Impl *impl_; 705 StateId s_; 706 }; 707 708 709 // Use this to make an arc iterator for a CacheBaseImpl-derived Fst, 710 // which must have types 'Arc' and 'State' defined. 
// Read-only arc iterator over one state of a CacheBaseImpl-derived
// FST.  Holds a ref count on the state so GC will not free it while
// the iterator is live.
template <class F,
          class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > >
class CacheArcIterator {
 public:
  typedef typename F::Arc Arc;
  typedef typename F::State State;
  typedef typename Arc::StateId StateId;
  typedef CacheBaseImpl<State, C> Impl;

  CacheArcIterator(Impl *impl, StateId s) : i_(0) {
    state_ = impl->ExtendState(s);
    ++state_->ref_count;  // pin the state against GC
  }

  ~CacheArcIterator() { --state_->ref_count; }

  bool Done() const { return i_ >= state_->arcs.size(); }

  const Arc& Value() const { return state_->arcs[i_]; }

  void Next() { ++i_; }

  size_t Position() const { return i_; }

  void Reset() { i_ = 0; }

  void Seek(size_t a) { i_ = a; }

  // All arc fields are always available; flags are fixed.
  uint32 Flags() const {
    return kArcValueFlags;
  }

  void SetFlags(uint32 flags, uint32 mask) {}

 private:
  const State *state_;  // the (pinned) cached state
  size_t i_;            // current arc position

  DISALLOW_COPY_AND_ASSIGN(CacheArcIterator);
};

// Use this to make a mutable arc iterator for a CacheBaseImpl-derived Fst,
// which must have types 'Arc' and 'State' defined.
template <class F,
          class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > >
class CacheMutableArcIterator
    : public MutableArcIteratorBase<typename F::Arc> {
 public:
  typedef typename F::State State;
  typedef typename F::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;
  typedef CacheBaseImpl<State, C> Impl;

  // You will need to call MutateCheck() in the constructor.
  CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) {
    state_ = impl_->ExtendState(s_);
    ++state_->ref_count;  // pin the state against GC
  };

  ~CacheMutableArcIterator() {
    --state_->ref_count;
  }

  bool Done() const { return i_ >= state_->arcs.size(); }

  const Arc& Value() const { return state_->arcs[i_]; }

  void Next() { ++i_; }

  size_t Position() const { return i_; }

  void Reset() { i_ = 0; }

  void Seek(size_t a) { i_ = a; }

  // Replaces the current arc, updating the state's epsilon counts
  // and the impl's property bits to reflect the old arc's removal
  // and the new arc's addition.
  void SetValue(const Arc& arc) {
    state_->flags |= CacheBaseImpl<State, C>::kCacheModified;
    uint64 properties = impl_->Properties();
    Arc& oarc = state_->arcs[i_];
    // Retract properties implied by the outgoing arc.
    if (oarc.ilabel != oarc.olabel)
      properties &= ~kNotAcceptor;
    if (oarc.ilabel == 0) {
      --state_->niepsilons;
      properties &= ~kIEpsilons;
      if (oarc.olabel == 0)
        properties &= ~kEpsilons;
    }
    if (oarc.olabel == 0) {
      --state_->noepsilons;
      properties &= ~kOEpsilons;
    }
    if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One())
      properties &= ~kWeighted;
    oarc = arc;
    // Assert properties implied by the incoming arc.
    if (arc.ilabel != arc.olabel) {
      properties |= kNotAcceptor;
      properties &= ~kAcceptor;
    }
    if (arc.ilabel == 0) {
      ++state_->niepsilons;
      properties |= kIEpsilons;
      properties &= ~kNoIEpsilons;
      if (arc.olabel == 0) {
        properties |= kEpsilons;
        properties &= ~kNoEpsilons;
      }
    }
    if (arc.olabel == 0) {
      ++state_->noepsilons;
      properties |= kOEpsilons;
      properties &= ~kNoOEpsilons;
    }
    if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) {
      properties |= kWeighted;
      properties &= ~kUnweighted;
    }
    // Keep only the property bits this mutation can validly assert.
    properties &= kSetArcProperties | kAcceptor | kNotAcceptor |
        kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons |
        kOEpsilons | kNoOEpsilons | kWeighted | kUnweighted;
    impl_->SetProperties(properties);
  }

  uint32 Flags() const {
    return kArcValueFlags;
  }

  void SetFlags(uint32 f, uint32 m) {}

 private:
  // Virtual forwarders giving the base class access to the
  // non-virtual members above.
  virtual bool Done_() const { return Done(); }
  virtual const Arc& Value_() const { return Value(); }
  virtual void Next_() { Next(); }
  virtual size_t Position_() const { return Position(); }
  virtual void Reset_() { Reset(); }
  virtual void Seek_(size_t a) { Seek(a); }
  virtual void SetValue_(const Arc &a) { SetValue(a); }
  uint32 Flags_() const { return Flags(); }
  void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); }

  size_t i_;       // current arc position
  StateId s_;      // state being iterated
  Impl *impl_;     // cache impl (not owned)
  State *state_;   // the (pinned) cached state

  DISALLOW_COPY_AND_ASSIGN(CacheMutableArcIterator);
};

}  // namespace fst

#endif  // FST_LIB_CACHE_H__