Home | History | Annotate | Download | only in fst
      1 // string-weight.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // String weight set and associated semiring operation definitions.
     20 
     21 #ifndef FST_LIB_STRING_WEIGHT_H__
     22 #define FST_LIB_STRING_WEIGHT_H__
     23 
     24 #include <list>
     25 #include <string>
     26 
     27 #include <fst/product-weight.h>
     28 #include <fst/weight.h>
     29 
     30 namespace fst {
     31 
     32 const int kStringInfinity = -1;      // Label for the infinite string
     33 const int kStringBad = -2;           // Label for a non-string
     34 const char kStringSeparator = '_';   // Label separator in strings
     35 
     36 // Determines whether to use left or right string semiring.  Includes
     37 // restricted versions that signal an error if proper prefixes
     38 // (suffixes) would otherwise be returned by Plus, useful with various
     39 // algorithms that require functional transducer input with the
     40 // string semirings.
     41 enum StringType { STRING_LEFT = 0, STRING_RIGHT = 1 ,
     42                   STRING_LEFT_RESTRICT = 2, STRING_RIGHT_RESTRICT };
     43 
     44 #define REVERSE_STRING_TYPE(S)                                  \
     45    ((S) == STRING_LEFT ? STRING_RIGHT :                         \
     46     ((S) == STRING_RIGHT ? STRING_LEFT :                        \
     47      ((S) == STRING_LEFT_RESTRICT ? STRING_RIGHT_RESTRICT :     \
     48       STRING_LEFT_RESTRICT)))
     49 
     50 template <typename L, StringType S = STRING_LEFT>
     51 class StringWeight;
     52 
     53 template <typename L, StringType S = STRING_LEFT>
     54 class StringWeightIterator;
     55 
     56 template <typename L, StringType S = STRING_LEFT>
     57 class StringWeightReverseIterator;
     58 
     59 template <typename L, StringType S>
     60 bool operator==(const StringWeight<L, S> &,  const StringWeight<L, S> &);
     61 
     62 
     63 // String semiring: (longest_common_prefix/suffix, ., Infinity, Epsilon)
     64 template <typename L, StringType S>
     65 class StringWeight {
     66  public:
     67   typedef L Label;
     68   typedef StringWeight<L, REVERSE_STRING_TYPE(S)> ReverseWeight;
     69 
     70   friend class StringWeightIterator<L, S>;
     71   friend class StringWeightReverseIterator<L, S>;
     72   friend bool operator==<>(const StringWeight<L, S> &,
     73                            const StringWeight<L, S> &);
     74 
     75   StringWeight() { Init(); }
     76 
     77   template <typename Iter>
     78   StringWeight(const Iter &begin, const Iter &end) {
     79     Init();
     80     for (Iter iter = begin; iter != end; ++iter)
     81       PushBack(*iter);
     82   }
     83 
     84   explicit StringWeight(L l) { Init(); PushBack(l); }
     85 
     86   static const StringWeight<L, S> &Zero() {
     87     static const StringWeight<L, S> zero(kStringInfinity);
     88     return zero;
     89   }
     90 
     91   static const StringWeight<L, S> &One() {
     92     static const StringWeight<L, S> one;
     93     return one;
     94   }
     95 
     96   static const StringWeight<L, S> &NoWeight() {
     97     static const StringWeight<L, S> no_weight(kStringBad);
     98     return no_weight;
     99   }
    100 
    101   static const string &Type() {
    102     static const string type =
    103         S == STRING_LEFT ? "string" :
    104         (S == STRING_RIGHT ? "right_string" :
    105          (S == STRING_LEFT_RESTRICT ? "restricted_string" :
    106           "right_restricted_string"));
    107     return type;
    108   }
    109 
    110   bool Member() const;
    111 
    112   istream &Read(istream &strm);
    113 
    114   ostream &Write(ostream &strm) const;
    115 
    116   size_t Hash() const;
    117 
    118   StringWeight<L, S> Quantize(float delta = kDelta) const {
    119     return *this;
    120   }
    121 
    122   ReverseWeight Reverse() const;
    123 
    124   static uint64 Properties() {
    125     return (S == STRING_LEFT || S == STRING_LEFT_RESTRICT ?
    126             kLeftSemiring : kRightSemiring) | kIdempotent;
    127   }
    128 
    129   // NB: This needs to be uncommented only if default fails for this impl.
    130   // StringWeight<L, S> &operator=(const StringWeight<L, S> &w);
    131 
    132   // These operations combined with the StringWeightIterator and
    133   // StringWeightReverseIterator provide the access and mutation of
    134   // the string internal elements.
    135 
    136   // Common initializer among constructors.
    137   void Init() { first_ = 0; }
    138 
    139   // Clear existing StringWeight.
    140   void Clear() { first_ = 0; rest_.clear(); }
    141 
    142   size_t Size() const { return first_ ? rest_.size() + 1 : 0; }
    143 
    144   void PushFront(L l) {
    145     if (first_)
    146       rest_.push_front(first_);
    147     first_ = l;
    148   }
    149 
    150   void PushBack(L l) {
    151     if (!first_)
    152       first_ = l;
    153     else
    154       rest_.push_back(l);
    155   }
    156 
    157  private:
    158   L first_;         // first label in string (0 if empty)
    159   list<L> rest_;    // remaining labels in string
    160 };
    161 
    162 
    163 // Traverses string in forward direction.
    164 template <typename L, StringType S>
    165 class StringWeightIterator {
    166  public:
    167   explicit StringWeightIterator(const StringWeight<L, S>& w)
    168       : first_(w.first_), rest_(w.rest_), init_(true),
    169         iter_(rest_.begin()) {}
    170 
    171   bool Done() const {
    172     if (init_) return first_ == 0;
    173     else return iter_ == rest_.end();
    174   }
    175 
    176   const L& Value() const { return init_ ? first_ : *iter_; }
    177 
    178   void Next() {
    179     if (init_) init_ = false;
    180     else  ++iter_;
    181   }
    182 
    183   void Reset() {
    184     init_ = true;
    185     iter_ = rest_.begin();
    186   }
    187 
    188  private:
    189   const L &first_;
    190   const list<L> &rest_;
    191   bool init_;   // in the initialized state?
    192   typename list<L>::const_iterator iter_;
    193 
    194   DISALLOW_COPY_AND_ASSIGN(StringWeightIterator);
    195 };
    196 
    197 
    198 // Traverses string in backward direction.
    199 template <typename L, StringType S>
    200 class StringWeightReverseIterator {
    201  public:
    202   explicit StringWeightReverseIterator(const StringWeight<L, S>& w)
    203       : first_(w.first_), rest_(w.rest_), fin_(first_ == 0),
    204         iter_(rest_.rbegin()) {}
    205 
    206   bool Done() const { return fin_; }
    207 
    208   const L& Value() const { return iter_ == rest_.rend() ? first_ : *iter_; }
    209 
    210   void Next() {
    211     if (iter_ == rest_.rend()) fin_ = true;
    212     else  ++iter_;
    213   }
    214 
    215   void Reset() {
    216     fin_ = false;
    217     iter_ = rest_.rbegin();
    218   }
    219 
    220  private:
    221   const L &first_;
    222   const list<L> &rest_;
    223   bool fin_;   // in the final state?
    224   typename list<L>::const_reverse_iterator iter_;
    225 
    226   DISALLOW_COPY_AND_ASSIGN(StringWeightReverseIterator);
    227 };
    228 
    229 
    230 // StringWeight member functions follow that require
    231 // StringWeightIterator or StringWeightReverseIterator.
    232 
    233 template <typename L, StringType S>
    234 inline istream &StringWeight<L, S>::Read(istream &strm) {
    235   Clear();
    236   int32 size;
    237   ReadType(strm, &size);
    238   for (int i = 0; i < size; ++i) {
    239     L label;
    240     ReadType(strm, &label);
    241     PushBack(label);
    242   }
    243   return strm;
    244 }
    245 
    246 template <typename L, StringType S>
    247 inline ostream &StringWeight<L, S>::Write(ostream &strm) const {
    248   int32 size =  Size();
    249   WriteType(strm, size);
    250   for (StringWeightIterator<L, S> iter(*this); !iter.Done(); iter.Next()) {
    251     L label = iter.Value();
    252     WriteType(strm, label);
    253   }
    254   return strm;
    255 }
    256 
    257 template <typename L, StringType S>
    258 inline bool StringWeight<L, S>::Member() const {
    259   if (Size() != 1)
    260     return true;
    261   StringWeightIterator<L, S> iter(*this);
    262   return iter.Value() != kStringBad;
    263 }
    264 
    265 template <typename L, StringType S>
    266 inline typename StringWeight<L, S>::ReverseWeight
    267 StringWeight<L, S>::Reverse() const {
    268   ReverseWeight rw;
    269   for (StringWeightIterator<L, S> iter(*this); !iter.Done(); iter.Next())
    270     rw.PushFront(iter.Value());
    271   return rw;
    272 }
    273 
    274 template <typename L, StringType S>
    275 inline size_t StringWeight<L, S>::Hash() const {
    276   size_t h = 0;
    277   for (StringWeightIterator<L, S> iter(*this); !iter.Done(); iter.Next())
    278     h ^= h<<1 ^ iter.Value();
    279   return h;
    280 }
    281 
    282 // NB: This needs to be uncommented only if default fails for this impl.
    283 //
    284 // template <typename L, StringType S>
    285 // inline StringWeight<L, S>
    286 // &StringWeight<L, S>::operator=(const StringWeight<L, S> &w) {
    287 //   if (this != &w) {
    288 //     Clear();
    289 //     for (StringWeightIterator<L, S> iter(w); !iter.Done(); iter.Next())
    290 //       PushBack(iter.Value());
    291 //   }
    292 //   return *this;
    293 // }
    294 
    295 template <typename L, StringType S>
    296 inline bool operator==(const StringWeight<L, S> &w1,
    297                        const StringWeight<L, S> &w2) {
    298   if (w1.Size() != w2.Size())
    299     return false;
    300 
    301   StringWeightIterator<L, S> iter1(w1);
    302   StringWeightIterator<L, S> iter2(w2);
    303 
    304   for (; !iter1.Done() ; iter1.Next(), iter2.Next())
    305     if (iter1.Value() != iter2.Value())
    306       return false;
    307 
    308   return true;
    309 }
    310 
    311 template <typename L, StringType S>
    312 inline bool operator!=(const StringWeight<L, S> &w1,
    313                        const StringWeight<L, S> &w2) {
    314   return !(w1 == w2);
    315 }
    316 
    317 template <typename L, StringType S>
    318 inline bool ApproxEqual(const StringWeight<L, S> &w1,
    319                         const StringWeight<L, S> &w2,
    320                         float delta = kDelta) {
    321   return w1 == w2;
    322 }
    323 
    324 template <typename L, StringType S>
    325 inline ostream &operator<<(ostream &strm, const StringWeight<L, S> &w) {
    326   StringWeightIterator<L, S> iter(w);
    327   if (iter.Done())
    328     return strm << "Epsilon";
    329   else if (iter.Value() == kStringInfinity)
    330     return strm << "Infinity";
    331   else if (iter.Value() == kStringBad)
    332     return strm << "BadString";
    333   else
    334     for (size_t i = 0; !iter.Done(); ++i, iter.Next()) {
    335       if (i > 0)
    336         strm << kStringSeparator;
    337       strm << iter.Value();
    338     }
    339   return strm;
    340 }
    341 
    342 template <typename L, StringType S>
    343 inline istream &operator>>(istream &strm, StringWeight<L, S> &w) {
    344   string s;
    345   strm >> s;
    346   if (s == "Infinity") {
    347     w = StringWeight<L, S>::Zero();
    348   } else if (s == "Epsilon") {
    349     w = StringWeight<L, S>::One();
    350   } else {
    351     w.Clear();
    352     char *p = 0;
    353     for (const char *cs = s.c_str(); !p || *p != '\0'; cs = p + 1) {
    354       int l = strtoll(cs, &p, 10);
    355       if (p == cs || (*p != 0 && *p != kStringSeparator)) {
    356         strm.clear(std::ios::badbit);
    357         break;
    358       }
    359       w.PushBack(l);
    360     }
    361   }
    362   return strm;
    363 }
    364 
    365 
    366 // Default is for the restricted left and right semirings.  String
    367 // equality is required (for non-Zero() input. This restriction
    368 // is used in e.g. Determinize to ensure functional input.
    369 template <typename L, StringType S>  inline StringWeight<L, S>
    370 Plus(const StringWeight<L, S> &w1,
    371      const StringWeight<L, S> &w2) {
    372   if (!w1.Member() || !w2.Member())
    373     return StringWeight<L, S>::NoWeight();
    374   if (w1 == StringWeight<L, S>::Zero())
    375     return w2;
    376   if (w2 == StringWeight<L, S>::Zero())
    377     return w1;
    378 
    379   if (w1 != w2) {
    380     FSTERROR() << "StringWeight::Plus: unequal arguments "
    381                << "(non-functional FST?)"
    382                << " w1 = " << w1
    383                << " w2 = " << w2;
    384     return StringWeight<L, S>::NoWeight();
    385   }
    386 
    387   return w1;
    388 }
    389 
    390 
    391 // Longest common prefix for left string semiring.
    392 template <typename L>  inline StringWeight<L, STRING_LEFT>
    393 Plus(const StringWeight<L, STRING_LEFT> &w1,
    394      const StringWeight<L, STRING_LEFT> &w2) {
    395   if (!w1.Member() || !w2.Member())
    396     return StringWeight<L, STRING_LEFT>::NoWeight();
    397   if (w1 == StringWeight<L, STRING_LEFT>::Zero())
    398     return w2;
    399   if (w2 == StringWeight<L, STRING_LEFT>::Zero())
    400     return w1;
    401 
    402   StringWeight<L, STRING_LEFT> sum;
    403   StringWeightIterator<L, STRING_LEFT> iter1(w1);
    404   StringWeightIterator<L, STRING_LEFT> iter2(w2);
    405   for (; !iter1.Done() && !iter2.Done() && iter1.Value() == iter2.Value();
    406        iter1.Next(), iter2.Next())
    407     sum.PushBack(iter1.Value());
    408   return sum;
    409 }
    410 
    411 
    412 // Longest common suffix for right string semiring.
    413 template <typename L>  inline StringWeight<L, STRING_RIGHT>
    414 Plus(const StringWeight<L, STRING_RIGHT> &w1,
    415      const StringWeight<L, STRING_RIGHT> &w2) {
    416   if (!w1.Member() || !w2.Member())
    417     return StringWeight<L, STRING_RIGHT>::NoWeight();
    418   if (w1 == StringWeight<L, STRING_RIGHT>::Zero())
    419     return w2;
    420   if (w2 == StringWeight<L, STRING_RIGHT>::Zero())
    421     return w1;
    422 
    423   StringWeight<L, STRING_RIGHT> sum;
    424   StringWeightReverseIterator<L, STRING_RIGHT> iter1(w1);
    425   StringWeightReverseIterator<L, STRING_RIGHT> iter2(w2);
    426   for (; !iter1.Done() && !iter2.Done() && iter1.Value() == iter2.Value();
    427        iter1.Next(), iter2.Next())
    428     sum.PushFront(iter1.Value());
    429   return sum;
    430 }
    431 
    432 
    433 template <typename L, StringType S>
    434 inline StringWeight<L, S> Times(const StringWeight<L, S> &w1,
    435                              const StringWeight<L, S> &w2) {
    436   if (!w1.Member() || !w2.Member())
    437     return StringWeight<L, S>::NoWeight();
    438   if (w1 == StringWeight<L, S>::Zero() || w2 == StringWeight<L, S>::Zero())
    439     return StringWeight<L, S>::Zero();
    440 
    441   StringWeight<L, S> prod(w1);
    442   for (StringWeightIterator<L, S> iter(w2); !iter.Done(); iter.Next())
    443     prod.PushBack(iter.Value());
    444 
    445   return prod;
    446 }
    447 
    448 
    449 // Default is for left division in the left string and the
    450 // left restricted string semirings.
    451 template <typename L, StringType S> inline StringWeight<L, S>
    452 Divide(const StringWeight<L, S> &w1,
    453        const StringWeight<L, S> &w2,
    454        DivideType typ) {
    455 
    456   if (typ != DIVIDE_LEFT) {
    457     FSTERROR() << "StringWeight::Divide: only left division is defined "
    458                << "for the " << StringWeight<L, S>::Type() << " semiring";
    459     return StringWeight<L, S>::NoWeight();
    460   }
    461 
    462   if (!w1.Member() || !w2.Member())
    463     return StringWeight<L, S>::NoWeight();
    464 
    465   if (w2 == StringWeight<L, S>::Zero())
    466     return StringWeight<L, S>(kStringBad);
    467   else if (w1 == StringWeight<L, S>::Zero())
    468     return StringWeight<L, S>::Zero();
    469 
    470   StringWeight<L, S> div;
    471   StringWeightIterator<L, S> iter(w1);
    472   for (int i = 0; !iter.Done(); iter.Next(), ++i) {
    473     if (i >= w2.Size())
    474       div.PushBack(iter.Value());
    475   }
    476   return div;
    477 }
    478 
    479 
    480 // Right division in the right string semiring.
    481 template <typename L> inline StringWeight<L, STRING_RIGHT>
    482 Divide(const StringWeight<L, STRING_RIGHT> &w1,
    483        const StringWeight<L, STRING_RIGHT> &w2,
    484        DivideType typ) {
    485 
    486   if (typ != DIVIDE_RIGHT) {
    487     FSTERROR() << "StringWeight::Divide: only right division is defined "
    488                << "for the right string semiring";
    489     return StringWeight<L, STRING_RIGHT>::NoWeight();
    490   }
    491 
    492   if (!w1.Member() || !w2.Member())
    493     return StringWeight<L, STRING_RIGHT>::NoWeight();
    494 
    495   if (w2 == StringWeight<L, STRING_RIGHT>::Zero())
    496     return StringWeight<L, STRING_RIGHT>(kStringBad);
    497   else if (w1 == StringWeight<L, STRING_RIGHT>::Zero())
    498     return StringWeight<L, STRING_RIGHT>::Zero();
    499 
    500   StringWeight<L, STRING_RIGHT> div;
    501   StringWeightReverseIterator<L, STRING_RIGHT> iter(w1);
    502   for (int i = 0; !iter.Done(); iter.Next(), ++i) {
    503     if (i >= w2.Size())
    504       div.PushFront(iter.Value());
    505   }
    506   return div;
    507 }
    508 
    509 
    510 // Right division in the right restricted string semiring.
    511 template <typename L> inline StringWeight<L, STRING_RIGHT_RESTRICT>
    512 Divide(const StringWeight<L, STRING_RIGHT_RESTRICT> &w1,
    513        const StringWeight<L, STRING_RIGHT_RESTRICT> &w2,
    514        DivideType typ) {
    515 
    516   if (typ != DIVIDE_RIGHT) {
    517     FSTERROR() << "StringWeight::Divide: only right division is defined "
    518                << "for the right restricted string semiring";
    519     return StringWeight<L, STRING_RIGHT_RESTRICT>::NoWeight();
    520   }
    521 
    522   if (!w1.Member() || !w2.Member())
    523     return StringWeight<L, STRING_RIGHT_RESTRICT>::NoWeight();
    524 
    525   if (w2 == StringWeight<L, STRING_RIGHT_RESTRICT>::Zero())
    526     return StringWeight<L, STRING_RIGHT_RESTRICT>(kStringBad);
    527   else if (w1 == StringWeight<L, STRING_RIGHT_RESTRICT>::Zero())
    528     return StringWeight<L, STRING_RIGHT_RESTRICT>::Zero();
    529 
    530   StringWeight<L, STRING_RIGHT_RESTRICT> div;
    531   StringWeightReverseIterator<L, STRING_RIGHT_RESTRICT> iter(w1);
    532   for (int i = 0; !iter.Done(); iter.Next(), ++i) {
    533     if (i >= w2.Size())
    534       div.PushFront(iter.Value());
    535   }
    536   return div;
    537 }
    538 
    539 
    540 // Product of string weight and an arbitray weight.
    541 template <class L, class W, StringType S = STRING_LEFT>
    542 struct GallicWeight : public ProductWeight<StringWeight<L, S>, W> {
    543   typedef GallicWeight<L, typename W::ReverseWeight, REVERSE_STRING_TYPE(S)>
    544   ReverseWeight;
    545 
    546   GallicWeight() {}
    547 
    548   GallicWeight(StringWeight<L, S> w1, W w2)
    549       : ProductWeight<StringWeight<L, S>, W>(w1, w2) {}
    550 
    551   explicit GallicWeight(const string &s, int *nread = 0)
    552       : ProductWeight<StringWeight<L, S>, W>(s, nread) {}
    553 
    554   GallicWeight(const ProductWeight<StringWeight<L, S>, W> &w)
    555       : ProductWeight<StringWeight<L, S>, W>(w) {}
    556 };
    557 
    558 }  // namespace fst
    559 
    560 #endif  // FST_LIB_STRING_WEIGHT_H__
    561