Home | History | Annotate | Download | only in fst
      1 
      2 // Licensed under the Apache License, Version 2.0 (the "License");
      3 // you may not use this file except in compliance with the License.
      4 // You may obtain a copy of the License at
      5 //
      6 //     http://www.apache.org/licenses/LICENSE-2.0
      7 //
      8 // Unless required by applicable law or agreed to in writing, software
      9 // distributed under the License is distributed on an "AS IS" BASIS,
     10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     11 // See the License for the specific language governing permissions and
     12 // limitations under the License.
     13 //
     14 // Copyright 2005-2010 Google, Inc.
     15 // Author: krr (at) google.com (Kasturi Rangan Raghavan)
     16 // \file
     17 // LogWeight along with sign information that represents the value X in the
     18 // linear domain as <sign(X), -ln(|X|)>
     19 // The sign is a TropicalWeight:
     20 //  positive, TropicalWeight.Value() > 0.0, recommended value 1.0
     21 //  negative, TropicalWeight.Value() <= 0.0, recommended value -1.0
     22 
     23 #ifndef FST_LIB_SIGNED_LOG_WEIGHT_H_
     24 #define FST_LIB_SIGNED_LOG_WEIGHT_H_
     25 
     26 #include <fst/float-weight.h>
     27 #include <fst/pair-weight.h>
     28 
     29 
     30 namespace fst {
     31 template <class T>
     32 class SignedLogWeightTpl
     33     : public PairWeight<TropicalWeight, LogWeightTpl<T> > {
     34  public:
     35   typedef TropicalWeight X1;
     36   typedef LogWeightTpl<T> X2;
     37   using PairWeight<X1, X2>::Value1;
     38   using PairWeight<X1, X2>::Value2;
     39 
     40   using PairWeight<X1, X2>::Reverse;
     41   using PairWeight<X1, X2>::Quantize;
     42   using PairWeight<X1, X2>::Member;
     43 
     44   typedef SignedLogWeightTpl<T> ReverseWeight;
     45 
     46   SignedLogWeightTpl() : PairWeight<X1, X2>() {}
     47 
     48   SignedLogWeightTpl(const SignedLogWeightTpl<T>& w)
     49       : PairWeight<X1, X2> (w) { }
     50 
     51   SignedLogWeightTpl(const PairWeight<X1, X2>& w)
     52       : PairWeight<X1, X2> (w) { }
     53 
     54   SignedLogWeightTpl(const X1& x1, const X2& x2)
     55       : PairWeight<X1, X2>(x1, x2) { }
     56 
     57   static const SignedLogWeightTpl<T> &Zero() {
     58     static const SignedLogWeightTpl<T> zero(X1(1.0), X2::Zero());
     59     return zero;
     60   }
     61 
     62   static const SignedLogWeightTpl<T> &One() {
     63     static const SignedLogWeightTpl<T> one(X1(1.0), X2::One());
     64     return one;
     65   }
     66 
     67   static const SignedLogWeightTpl<T> &NoWeight() {
     68     static const SignedLogWeightTpl<T> no_weight(X1(1.0), X2::NoWeight());
     69     return no_weight;
     70   }
     71 
     72   static const string &Type() {
     73     static const string type = "signed_log_" + X1::Type() + "_" + X2::Type();
     74     return type;
     75   }
     76 
     77   ProductWeight<X1, X2> Quantize(float delta = kDelta) const {
     78     return PairWeight<X1, X2>::Quantize();
     79   }
     80 
     81   ReverseWeight Reverse() const {
     82     return PairWeight<X1, X2>::Reverse();
     83   }
     84 
     85   bool Member() const {
     86     return PairWeight<X1, X2>::Member();
     87   }
     88 
     89   static uint64 Properties() {
     90     // not idempotent nor path
     91     return kLeftSemiring | kRightSemiring | kCommutative;
     92   }
     93 
     94   size_t Hash() const {
     95     size_t h1;
     96     if (Value2() == X2::Zero() || Value1().Value() > 0.0)
     97       h1 = TropicalWeight(1.0).Hash();
     98     else
     99       h1 = TropicalWeight(-1.0).Hash();
    100     size_t h2 = Value2().Hash();
    101     const int lshift = 5;
    102     const int rshift = CHAR_BIT * sizeof(size_t) - 5;
    103     return h1 << lshift ^ h1 >> rshift ^ h2;
    104   }
    105 };
    106 
    107 template <class T>
    108 inline SignedLogWeightTpl<T> Plus(const SignedLogWeightTpl<T> &w1,
    109                                   const SignedLogWeightTpl<T> &w2) {
    110   if (!w1.Member() || !w2.Member())
    111     return SignedLogWeightTpl<T>::NoWeight();
    112   bool s1 = w1.Value1().Value() > 0.0;
    113   bool s2 = w2.Value1().Value() > 0.0;
    114   T f1 = w1.Value2().Value();
    115   T f2 = w2.Value2().Value();
    116   if (f1 == FloatLimits<T>::kPosInfinity)
    117     return w2;
    118   else if (f2 == FloatLimits<T>::kPosInfinity)
    119     return w1;
    120   else if (f1 == f2) {
    121     if (s1 == s2)
    122       return SignedLogWeightTpl<T>(w1.Value1(), (f2 - log(2.0F)));
    123     else
    124       return SignedLogWeightTpl<T>::Zero();
    125   } else if (f1 > f2) {
    126     if (s1 == s2) {
    127       return SignedLogWeightTpl<T>(
    128         w1.Value1(), (f2 - log(1.0F + exp(f2 - f1))));
    129     } else {
    130       return SignedLogWeightTpl<T>(
    131         w2.Value1(), (f2 - log(1.0F - exp(f2 - f1))));
    132     }
    133   } else {
    134     if (s2 == s1) {
    135       return SignedLogWeightTpl<T>(
    136         w2.Value1(), (f1 - log(1.0F + exp(f1 - f2))));
    137     } else {
    138       return SignedLogWeightTpl<T>(
    139         w1.Value1(), (f1 - log(1.0F - exp(f1 - f2))));
    140     }
    141   }
    142 }
    143 
    144 template <class T>
    145 inline SignedLogWeightTpl<T> Minus(const SignedLogWeightTpl<T> &w1,
    146                                    const SignedLogWeightTpl<T> &w2) {
    147   SignedLogWeightTpl<T> minus_w2(-w2.Value1().Value(), w2.Value2());
    148   return Plus(w1, minus_w2);
    149 }
    150 
    151 template <class T>
    152 inline SignedLogWeightTpl<T> Times(const SignedLogWeightTpl<T> &w1,
    153                                    const SignedLogWeightTpl<T> &w2) {
    154   if (!w1.Member() || !w2.Member())
    155     return SignedLogWeightTpl<T>::NoWeight();
    156   bool s1 = w1.Value1().Value() > 0.0;
    157   bool s2 = w2.Value1().Value() > 0.0;
    158   T f1 = w1.Value2().Value();
    159   T f2 = w2.Value2().Value();
    160   if (s1 == s2)
    161     return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 + f2));
    162   else
    163     return SignedLogWeightTpl<T>(TropicalWeight(-1.0), (f1 + f2));
    164 }
    165 
    166 template <class T>
    167 inline SignedLogWeightTpl<T> Divide(const SignedLogWeightTpl<T> &w1,
    168                                     const SignedLogWeightTpl<T> &w2,
    169                                     DivideType typ = DIVIDE_ANY) {
    170   if (!w1.Member() || !w2.Member())
    171     return SignedLogWeightTpl<T>::NoWeight();
    172   bool s1 = w1.Value1().Value() > 0.0;
    173   bool s2 = w2.Value1().Value() > 0.0;
    174   T f1 = w1.Value2().Value();
    175   T f2 = w2.Value2().Value();
    176   if (f2 == FloatLimits<T>::kPosInfinity)
    177     return SignedLogWeightTpl<T>(TropicalWeight(1.0),
    178       FloatLimits<T>::kNumberBad);
    179   else if (f1 == FloatLimits<T>::kPosInfinity)
    180     return SignedLogWeightTpl<T>(TropicalWeight(1.0),
    181       FloatLimits<T>::kPosInfinity);
    182   else if (s1 == s2)
    183     return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 - f2));
    184   else
    185     return SignedLogWeightTpl<T>(TropicalWeight(-1.0), (f1 - f2));
    186 }
    187 
    188 template <class T>
    189 inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1,
    190                         const SignedLogWeightTpl<T> &w2,
    191                         float delta = kDelta) {
    192   bool s1 = w1.Value1().Value() > 0.0;
    193   bool s2 = w2.Value1().Value() > 0.0;
    194   if (s1 == s2) {
    195     return ApproxEqual(w1.Value2(), w2.Value2(), delta);
    196   } else {
    197     return w1.Value2() == LogWeightTpl<T>::Zero()
    198         && w2.Value2() == LogWeightTpl<T>::Zero();
    199   }
    200 }
    201 
    202 template <class T>
    203 inline bool operator==(const SignedLogWeightTpl<T> &w1,
    204                        const SignedLogWeightTpl<T> &w2) {
    205   bool s1 = w1.Value1().Value() > 0.0;
    206   bool s2 = w2.Value1().Value() > 0.0;
    207   if (s1 == s2)
    208     return w1.Value2() == w2.Value2();
    209   else
    210     return (w1.Value2() == LogWeightTpl<T>::Zero()) &&
    211            (w2.Value2() == LogWeightTpl<T>::Zero());
    212 }
    213 
    214 
    215 // Single-precision signed-log weight
    216 typedef SignedLogWeightTpl<float> SignedLogWeight;
    217 // Double-precision signed-log weight
    218 typedef SignedLogWeightTpl<double> SignedLog64Weight;
    219 
    220 //
    221 // WEIGHT CONVERTER SPECIALIZATIONS.
    222 //
    223 
    224 template <class W1, class W2>
    225 bool SignedLogConvertCheck(W1 w) {
    226   if (w.Value1().Value() < 0.0) {
    227     FSTERROR() << "WeightConvert: can't convert weight from \""
    228                << W1::Type() << "\" to \"" << W2::Type();
    229     return false;
    230   }
    231   return true;
    232 }
    233 
    234 // Convert to tropical
    235 template <>
    236 struct WeightConvert<SignedLogWeight, TropicalWeight> {
    237   TropicalWeight operator()(SignedLogWeight w) const {
    238     if (!SignedLogConvertCheck<SignedLogWeight, TropicalWeight>(w))
    239       return TropicalWeight::NoWeight();
    240     return w.Value2().Value();
    241   }
    242 };
    243 
    244 template <>
    245 struct WeightConvert<SignedLog64Weight, TropicalWeight> {
    246   TropicalWeight operator()(SignedLog64Weight w) const {
    247     if (!SignedLogConvertCheck<SignedLog64Weight, TropicalWeight>(w))
    248       return TropicalWeight::NoWeight();
    249     return w.Value2().Value();
    250   }
    251 };
    252 
    253 // Convert to log
    254 template <>
    255 struct WeightConvert<SignedLogWeight, LogWeight> {
    256   LogWeight operator()(SignedLogWeight w) const {
    257     if (!SignedLogConvertCheck<SignedLogWeight, LogWeight>(w))
    258       return LogWeight::NoWeight();
    259     return w.Value2().Value();
    260   }
    261 };
    262 
    263 template <>
    264 struct WeightConvert<SignedLog64Weight, LogWeight> {
    265   LogWeight operator()(SignedLog64Weight w) const {
    266     if (!SignedLogConvertCheck<SignedLog64Weight, LogWeight>(w))
    267       return LogWeight::NoWeight();
    268     return w.Value2().Value();
    269   }
    270 };
    271 
    272 // Convert to log64
    273 template <>
    274 struct WeightConvert<SignedLogWeight, Log64Weight> {
    275   Log64Weight operator()(SignedLogWeight w) const {
    276     if (!SignedLogConvertCheck<SignedLogWeight, Log64Weight>(w))
    277       return Log64Weight::NoWeight();
    278     return w.Value2().Value();
    279   }
    280 };
    281 
    282 template <>
    283 struct WeightConvert<SignedLog64Weight, Log64Weight> {
    284   Log64Weight operator()(SignedLog64Weight w) const {
    285     if (!SignedLogConvertCheck<SignedLog64Weight, Log64Weight>(w))
    286       return Log64Weight::NoWeight();
    287     return w.Value2().Value();
    288   }
    289 };
    290 
    291 // Convert to signed log
    292 template <>
    293 struct WeightConvert<TropicalWeight, SignedLogWeight> {
    294   SignedLogWeight operator()(TropicalWeight w) const {
    295     TropicalWeight x1 = 1.0;
    296     LogWeight x2 = w.Value();
    297     return SignedLogWeight(x1, x2);
    298   }
    299 };
    300 
    301 template <>
    302 struct WeightConvert<LogWeight, SignedLogWeight> {
    303   SignedLogWeight operator()(LogWeight w) const {
    304     TropicalWeight x1 = 1.0;
    305     LogWeight x2 = w.Value();
    306     return SignedLogWeight(x1, x2);
    307   }
    308 };
    309 
    310 template <>
    311 struct WeightConvert<Log64Weight, SignedLogWeight> {
    312   SignedLogWeight operator()(Log64Weight w) const {
    313     TropicalWeight x1 = 1.0;
    314     LogWeight x2 = w.Value();
    315     return SignedLogWeight(x1, x2);
    316   }
    317 };
    318 
    319 template <>
    320 struct WeightConvert<SignedLog64Weight, SignedLogWeight> {
    321   SignedLogWeight operator()(SignedLog64Weight w) const {
    322     TropicalWeight x1 = w.Value1();
    323     LogWeight x2 = w.Value2().Value();
    324     return SignedLogWeight(x1, x2);
    325   }
    326 };
    327 
    328 // Convert to signed log64
    329 template <>
    330 struct WeightConvert<TropicalWeight, SignedLog64Weight> {
    331   SignedLog64Weight operator()(TropicalWeight w) const {
    332     TropicalWeight x1 = 1.0;
    333     Log64Weight x2 = w.Value();
    334     return SignedLog64Weight(x1, x2);
    335   }
    336 };
    337 
    338 template <>
    339 struct WeightConvert<LogWeight, SignedLog64Weight> {
    340   SignedLog64Weight operator()(LogWeight w) const {
    341     TropicalWeight x1 = 1.0;
    342     Log64Weight x2 = w.Value();
    343     return SignedLog64Weight(x1, x2);
    344   }
    345 };
    346 
    347 template <>
    348 struct WeightConvert<Log64Weight, SignedLog64Weight> {
    349   SignedLog64Weight operator()(Log64Weight w) const {
    350     TropicalWeight x1 = 1.0;
    351     Log64Weight x2 = w.Value();
    352     return SignedLog64Weight(x1, x2);
    353   }
    354 };
    355 
    356 template <>
    357 struct WeightConvert<SignedLogWeight, SignedLog64Weight> {
    358   SignedLog64Weight operator()(SignedLogWeight w) const {
    359     TropicalWeight x1 = w.Value1();
    360     Log64Weight x2 = w.Value2().Value();
    361     return SignedLog64Weight(x1, x2);
    362   }
    363 };
    364 
    365 }  // namespace fst
    366 
    367 #endif  // FST_LIB_SIGNED_LOG_WEIGHT_H_
    368