1 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 // 14 // Copyright 2005-2010 Google, Inc. 15 // Author: krr (at) google.com (Kasturi Rangan Raghavan) 16 // \file 17 // LogWeight along with sign information that represents the value X in the 18 // linear domain as <sign(X), -ln(|X|)> 19 // The sign is a TropicalWeight: 20 // positive, TropicalWeight.Value() > 0.0, recommended value 1.0 21 // negative, TropicalWeight.Value() <= 0.0, recommended value -1.0 22 23 #ifndef FST_LIB_SIGNED_LOG_WEIGHT_H_ 24 #define FST_LIB_SIGNED_LOG_WEIGHT_H_ 25 26 #include <fst/float-weight.h> 27 #include <fst/pair-weight.h> 28 29 30 namespace fst { 31 template <class T> 32 class SignedLogWeightTpl 33 : public PairWeight<TropicalWeight, LogWeightTpl<T> > { 34 public: 35 typedef TropicalWeight X1; 36 typedef LogWeightTpl<T> X2; 37 using PairWeight<X1, X2>::Value1; 38 using PairWeight<X1, X2>::Value2; 39 40 using PairWeight<X1, X2>::Reverse; 41 using PairWeight<X1, X2>::Quantize; 42 using PairWeight<X1, X2>::Member; 43 44 typedef SignedLogWeightTpl<T> ReverseWeight; 45 46 SignedLogWeightTpl() : PairWeight<X1, X2>() {} 47 48 SignedLogWeightTpl(const SignedLogWeightTpl<T>& w) 49 : PairWeight<X1, X2> (w) { } 50 51 SignedLogWeightTpl(const PairWeight<X1, X2>& w) 52 : PairWeight<X1, X2> (w) { } 53 54 SignedLogWeightTpl(const X1& x1, const X2& x2) 55 : PairWeight<X1, X2>(x1, x2) { } 56 57 static const SignedLogWeightTpl<T> &Zero() { 58 static const SignedLogWeightTpl<T> zero(X1(1.0), X2::Zero()); 59 return zero; 60 } 61 62 static const SignedLogWeightTpl<T> &One() { 63 static const SignedLogWeightTpl<T> one(X1(1.0), X2::One()); 64 return one; 65 } 66 67 static const SignedLogWeightTpl<T> &NoWeight() { 68 static const SignedLogWeightTpl<T> no_weight(X1(1.0), X2::NoWeight()); 69 return no_weight; 70 } 71 72 static const string &Type() { 73 static const string type = "signed_log_" + X1::Type() + "_" + X2::Type(); 74 return type; 75 } 76 77 ProductWeight<X1, X2> Quantize(float delta = kDelta) const { 78 return PairWeight<X1, X2>::Quantize(); 79 } 80 81 ReverseWeight Reverse() const { 82 return PairWeight<X1, X2>::Reverse(); 83 } 84 85 bool Member() const { 86 return PairWeight<X1, X2>::Member(); 87 } 88 89 static uint64 Properties() { 90 // not idempotent nor path 91 return kLeftSemiring | kRightSemiring | kCommutative; 92 } 93 94 size_t Hash() const { 95 size_t h1; 96 if (Value2() == X2::Zero() || Value1().Value() > 0.0) 97 h1 = TropicalWeight(1.0).Hash(); 98 else 99 h1 = TropicalWeight(-1.0).Hash(); 100 size_t h2 = Value2().Hash(); 101 const int lshift = 5; 102 const int rshift = CHAR_BIT * sizeof(size_t) - 5; 103 return h1 << lshift ^ h1 >> rshift ^ h2; 104 } 105 }; 106 107 template <class T> 108 inline SignedLogWeightTpl<T> Plus(const SignedLogWeightTpl<T> &w1, 109 const SignedLogWeightTpl<T> &w2) { 110 if (!w1.Member() || !w2.Member()) 111 return SignedLogWeightTpl<T>::NoWeight(); 112 bool s1 = w1.Value1().Value() > 0.0; 113 bool s2 = w2.Value1().Value() > 0.0; 114 T f1 = w1.Value2().Value(); 115 T f2 = w2.Value2().Value(); 116 if (f1 == FloatLimits<T>::PosInfinity()) 117 return w2; 118 else if (f2 == FloatLimits<T>::PosInfinity()) 119 return w1; 120 else if (f1 == f2) { 121 if (s1 == s2) 122 return SignedLogWeightTpl<T>(w1.Value1(), (f2 - log(2.0F))); 123 else 124 return SignedLogWeightTpl<T>::Zero(); 125 } else if (f1 > f2) { 126 if (s1 == s2) { 127 return SignedLogWeightTpl<T>( 128 w1.Value1(), (f2 - log(1.0F + exp(f2 - f1)))); 129 } else { 130 return SignedLogWeightTpl<T>( 131 w2.Value1(), (f2 - log(1.0F - exp(f2 - f1)))); 132 } 133 } else { 134 if (s2 == s1) { 135 return SignedLogWeightTpl<T>( 136 w2.Value1(), (f1 - log(1.0F + exp(f1 - f2)))); 137 } else { 138 return SignedLogWeightTpl<T>( 139 w1.Value1(), (f1 - log(1.0F - exp(f1 - f2)))); 140 } 141 } 142 } 143 144 template <class T> 145 inline SignedLogWeightTpl<T> Minus(const SignedLogWeightTpl<T> &w1, 146 const SignedLogWeightTpl<T> &w2) { 147 SignedLogWeightTpl<T> minus_w2(-w2.Value1().Value(), w2.Value2()); 148 return Plus(w1, minus_w2); 149 } 150 151 template <class T> 152 inline SignedLogWeightTpl<T> Times(const SignedLogWeightTpl<T> &w1, 153 const SignedLogWeightTpl<T> &w2) { 154 if (!w1.Member() || !w2.Member()) 155 return SignedLogWeightTpl<T>::NoWeight(); 156 bool s1 = w1.Value1().Value() > 0.0; 157 bool s2 = w2.Value1().Value() > 0.0; 158 T f1 = w1.Value2().Value(); 159 T f2 = w2.Value2().Value(); 160 if (s1 == s2) 161 return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 + f2)); 162 else 163 return SignedLogWeightTpl<T>(TropicalWeight(-1.0), (f1 + f2)); 164 } 165 166 template <class T> 167 inline SignedLogWeightTpl<T> Divide(const SignedLogWeightTpl<T> &w1, 168 const SignedLogWeightTpl<T> &w2, 169 DivideType typ = DIVIDE_ANY) { 170 if (!w1.Member() || !w2.Member()) 171 return SignedLogWeightTpl<T>::NoWeight(); 172 bool s1 = w1.Value1().Value() > 0.0; 173 bool s2 = w2.Value1().Value() > 0.0; 174 T f1 = w1.Value2().Value(); 175 T f2 = w2.Value2().Value(); 176 if (f2 == FloatLimits<T>::PosInfinity()) 177 return SignedLogWeightTpl<T>(TropicalWeight(1.0), 178 FloatLimits<T>::NumberBad()); 179 else if (f1 == FloatLimits<T>::PosInfinity()) 180 return SignedLogWeightTpl<T>(TropicalWeight(1.0), 181 FloatLimits<T>::PosInfinity()); 182 else if (s1 == s2) 183 return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 - f2)); 184 else 185 return SignedLogWeightTpl<T>(TropicalWeight(-1.0), (f1 - f2)); 186 } 187 188 template <class T> 189 inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1, 190 const SignedLogWeightTpl<T> &w2, 191 float delta = kDelta) { 192 bool s1 = w1.Value1().Value() > 0.0; 193 bool s2 = w2.Value1().Value() > 0.0; 194 if (s1 == s2) { 195 return ApproxEqual(w1.Value2(), w2.Value2(), delta); 196 } else { 197 return w1.Value2() == LogWeightTpl<T>::Zero() 198 && w2.Value2() == LogWeightTpl<T>::Zero(); 199 } 200 } 201 202 template <class T> 203 inline bool operator==(const SignedLogWeightTpl<T> &w1, 204 const SignedLogWeightTpl<T> &w2) { 205 bool s1 = w1.Value1().Value() > 0.0; 206 bool s2 = w2.Value1().Value() > 0.0; 207 if (s1 == s2) 208 return w1.Value2() == w2.Value2(); 209 else 210 return (w1.Value2() == LogWeightTpl<T>::Zero()) && 211 (w2.Value2() == LogWeightTpl<T>::Zero()); 212 } 213 214 215 // Single-precision signed-log weight 216 typedef SignedLogWeightTpl<float> SignedLogWeight; 217 // Double-precision signed-log weight 218 typedef SignedLogWeightTpl<double> SignedLog64Weight; 219 220 // 221 // WEIGHT CONVERTER SPECIALIZATIONS. 222 // 223 224 template <class W1, class W2> 225 bool SignedLogConvertCheck(W1 w) { 226 if (w.Value1().Value() < 0.0) { 227 FSTERROR() << "WeightConvert: can't convert weight from \"" 228 << W1::Type() << "\" to \"" << W2::Type(); 229 return false; 230 } 231 return true; 232 } 233 234 // Convert to tropical 235 template <> 236 struct WeightConvert<SignedLogWeight, TropicalWeight> { 237 TropicalWeight operator()(SignedLogWeight w) const { 238 if (!SignedLogConvertCheck<SignedLogWeight, TropicalWeight>(w)) 239 return TropicalWeight::NoWeight(); 240 return w.Value2().Value(); 241 } 242 }; 243 244 template <> 245 struct WeightConvert<SignedLog64Weight, TropicalWeight> { 246 TropicalWeight operator()(SignedLog64Weight w) const { 247 if (!SignedLogConvertCheck<SignedLog64Weight, TropicalWeight>(w)) 248 return TropicalWeight::NoWeight(); 249 return w.Value2().Value(); 250 } 251 }; 252 253 // Convert to log 254 template <> 255 struct WeightConvert<SignedLogWeight, LogWeight> { 256 LogWeight operator()(SignedLogWeight w) const { 257 if (!SignedLogConvertCheck<SignedLogWeight, LogWeight>(w)) 258 return LogWeight::NoWeight(); 259 return w.Value2().Value(); 260 } 261 }; 262 263 template <> 264 struct WeightConvert<SignedLog64Weight, LogWeight> { 265 LogWeight operator()(SignedLog64Weight w) const { 266 if (!SignedLogConvertCheck<SignedLog64Weight, LogWeight>(w)) 267 return LogWeight::NoWeight(); 268 return w.Value2().Value(); 269 } 270 }; 271 272 // Convert to log64 273 template <> 274 struct WeightConvert<SignedLogWeight, Log64Weight> { 275 Log64Weight operator()(SignedLogWeight w) const { 276 if (!SignedLogConvertCheck<SignedLogWeight, Log64Weight>(w)) 277 return Log64Weight::NoWeight(); 278 return w.Value2().Value(); 279 } 280 }; 281 282 template <> 283 struct WeightConvert<SignedLog64Weight, Log64Weight> { 284 Log64Weight operator()(SignedLog64Weight w) const { 285 if (!SignedLogConvertCheck<SignedLog64Weight, Log64Weight>(w)) 286 return Log64Weight::NoWeight(); 287 return w.Value2().Value(); 288 } 289 }; 290 291 // Convert to signed log 292 template <> 293 struct WeightConvert<TropicalWeight, SignedLogWeight> { 294 SignedLogWeight operator()(TropicalWeight w) const { 295 TropicalWeight x1 = 1.0; 296 LogWeight x2 = w.Value(); 297 return SignedLogWeight(x1, x2); 298 } 299 }; 300 301 template <> 302 struct WeightConvert<LogWeight, SignedLogWeight> { 303 SignedLogWeight operator()(LogWeight w) const { 304 TropicalWeight x1 = 1.0; 305 LogWeight x2 = w.Value(); 306 return SignedLogWeight(x1, x2); 307 } 308 }; 309 310 template <> 311 struct WeightConvert<Log64Weight, SignedLogWeight> { 312 SignedLogWeight operator()(Log64Weight w) const { 313 TropicalWeight x1 = 1.0; 314 LogWeight x2 = w.Value(); 315 return SignedLogWeight(x1, x2); 316 } 317 }; 318 319 template <> 320 struct WeightConvert<SignedLog64Weight, SignedLogWeight> { 321 SignedLogWeight operator()(SignedLog64Weight w) const { 322 TropicalWeight x1 = w.Value1(); 323 LogWeight x2 = w.Value2().Value(); 324 return SignedLogWeight(x1, x2); 325 } 326 }; 327 328 // Convert to signed log64 329 template <> 330 struct WeightConvert<TropicalWeight, SignedLog64Weight> { 331 SignedLog64Weight operator()(TropicalWeight w) const { 332 TropicalWeight x1 = 1.0; 333 Log64Weight x2 = w.Value(); 334 return SignedLog64Weight(x1, x2); 335 } 336 }; 337 338 template <> 339 struct WeightConvert<LogWeight, SignedLog64Weight> { 340 SignedLog64Weight operator()(LogWeight w) const { 341 TropicalWeight x1 = 1.0; 342 Log64Weight x2 = w.Value(); 343 return SignedLog64Weight(x1, x2); 344 } 345 }; 346 347 template <> 348 struct WeightConvert<Log64Weight, SignedLog64Weight> { 349 SignedLog64Weight operator()(Log64Weight w) const { 350 TropicalWeight x1 = 1.0; 351 Log64Weight x2 = w.Value(); 352 return SignedLog64Weight(x1, x2); 353 } 354 }; 355 356 template <> 357 struct WeightConvert<SignedLogWeight, SignedLog64Weight> { 358 SignedLog64Weight operator()(SignedLogWeight w) const { 359 TropicalWeight x1 = w.Value1(); 360 Log64Weight x2 = w.Value2().Value(); 361 return SignedLog64Weight(x1, x2); 362 } 363 }; 364 365 } // namespace fst 366 367 #endif // FST_LIB_SIGNED_LOG_WEIGHT_H_ 368