1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ 17 #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ 18 19 #include <algorithm> 20 #include <cmath> 21 #include <limits> 22 #include <memory> 23 #include <vector> 24 25 #include "third_party/eigen3/Eigen/Core" 26 #include "tensorflow/core/lib/core/errors.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/gtl/top_n.h" 29 #include "tensorflow/core/platform/logging.h" 30 #include "tensorflow/core/platform/macros.h" 31 #include "tensorflow/core/platform/types.h" 32 #include "tensorflow/core/util/ctc/ctc_beam_entry.h" 33 #include "tensorflow/core/util/ctc/ctc_beam_scorer.h" 34 #include "tensorflow/core/util/ctc/ctc_decoder.h" 35 #include "tensorflow/core/util/ctc/ctc_loss_util.h" 36 37 namespace tensorflow { 38 namespace ctc { 39 40 template <typename CTCBeamState = ctc_beam_search::EmptyBeamState, 41 typename CTCBeamComparer = 42 ctc_beam_search::BeamComparer<CTCBeamState>> 43 class CTCBeamSearchDecoder : public CTCDecoder { 44 // Beam Search 45 // 46 // Example (GravesTh Fig. 7.5): 47 // a - 48 // P = [ 0.3 0.7 ] t = 0 49 // [ 0.4 0.6 ] t = 1 50 // 51 // Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42 52 // P(l = a) = P(a-) + P(aa) + P(-a) = 0.3*0.4 + ... = 0.58 53 // 54 // In this case, Best Path decoding is suboptimal. 55 // 56 // For Beam Search, we use the following main recurrence relations: 57 // 58 // Relation 1: 59 // ---------------------------------------------------------- Eq. 1 60 // P(l=abcd @ t=7) = P(l=abc @ t=6) * P(d @ 7) 61 // + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7)) 62 // where P(l=? @ t=7), ? = a, ab, abc, abcd are all stored and 63 // updated recursively in the beam entry. 64 // 65 // Relation 2: 66 // ---------------------------------------------------------- Eq. 2 67 // P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3) 68 // for ? in a, b, d, ..., (not including c or the blank index), 69 // and the recurrence starts from the beam entry for P(l=abc @ t=2). 70 // 71 // For this case, the length of the new sequence equals t+1 (t 72 // starts at 0). This special case can be calculated as: 73 // P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3) 74 // but we calculate it recursively for speed purposes. 75 typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry; 76 typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot; 77 typedef ctc_beam_search::BeamProbability BeamProbability; 78 79 public: 80 typedef BaseBeamScorer<CTCBeamState> DefaultBeamScorer; 81 82 // The beam search decoder is constructed specifying the beam_width (number of 83 // candidates to keep at each decoding timestep) and a beam scorer (used for 84 // custom scoring, for example enabling the use of a language model). 85 // The ownership of the scorer remains with the caller. The default 86 // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the 87 // standard beam search. 88 CTCBeamSearchDecoder(int num_classes, int beam_width, 89 BaseBeamScorer<CTCBeamState>* scorer, int batch_size = 1, 90 bool merge_repeated = false) 91 : CTCDecoder(num_classes, batch_size, merge_repeated), 92 beam_width_(beam_width), 93 leaves_(beam_width), 94 beam_scorer_(CHECK_NOTNULL(scorer)) { 95 Reset(); 96 } 97 98 ~CTCBeamSearchDecoder() override {} 99 100 // Run the hibernating beam search algorithm on the given input. 101 Status Decode(const CTCDecoder::SequenceLength& seq_len, 102 const std::vector<CTCDecoder::Input>& input, 103 std::vector<CTCDecoder::Output>* output, 104 CTCDecoder::ScoreOutput* scores) override; 105 106 // Calculate the next step of the beam search and update the internal state. 107 template <typename Vector> 108 void Step(const Vector& log_input_t); 109 110 template <typename Vector> 111 float GetTopK(const int K, const Vector& input, 112 std::vector<float>* top_k_logits, 113 std::vector<int>* top_k_indices); 114 115 // Retrieve the beam scorer instance used during decoding. 116 BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; } 117 118 // Set label selection parameters for faster decoding. 119 // See comments for label_selection_size_ and label_selection_margin_. 120 void SetLabelSelectionParameters(int label_selection_size, 121 float label_selection_margin) { 122 label_selection_size_ = label_selection_size; 123 label_selection_margin_ = label_selection_margin; 124 } 125 126 // Reset the beam search 127 void Reset(); 128 129 // Extract the top n paths at current time step 130 Status TopPaths(int n, std::vector<std::vector<int>>* paths, 131 std::vector<float>* log_probs, bool merge_repeated) const; 132 133 private: 134 int beam_width_; 135 136 // Label selection is designed to avoid possibly very expensive scorer calls, 137 // by pruning the hypotheses based on the input alone. 138 // Label selection size controls how many items in each beam are passed 139 // through to the beam scorer. Only items with top N input scores are 140 // considered. 141 // Label selection margin controls the difference between minimal input score 142 // (versus the best scoring label) for an item to be passed to the beam 143 // scorer. This margin is expressed in terms of log-probability. 144 // Default is to do no label selection. 145 // For more detail: https://research.google.com/pubs/pub44823.html 146 int label_selection_size_ = 0; // zero means unlimited 147 float label_selection_margin_ = -1; // -1 means unlimited. 148 149 gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_; 150 std::unique_ptr<BeamRoot> beam_root_; 151 BaseBeamScorer<CTCBeamState>* beam_scorer_; 152 153 TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder); 154 }; 155 156 template <typename CTCBeamState, typename CTCBeamComparer> 157 Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode( 158 const CTCDecoder::SequenceLength& seq_len, 159 const std::vector<CTCDecoder::Input>& input, 160 std::vector<CTCDecoder::Output>* output, ScoreOutput* scores) { 161 // Storage for top paths. 162 std::vector<std::vector<int>> beams; 163 std::vector<float> beam_log_probabilities; 164 int top_n = output->size(); 165 if (std::any_of(output->begin(), output->end(), 166 [this](const CTCDecoder::Output& output) -> bool { 167 return output.size() < this->batch_size_; 168 })) { 169 return errors::InvalidArgument( 170 "output needs to be of size at least (top_n, batch_size)."); 171 } 172 if (scores->rows() < batch_size_ || scores->cols() < top_n) { 173 return errors::InvalidArgument( 174 "scores needs to be of size at least (batch_size, top_n)."); 175 } 176 177 for (int b = 0; b < batch_size_; ++b) { 178 int seq_len_b = seq_len[b]; 179 Reset(); 180 181 for (int t = 0; t < seq_len_b; ++t) { 182 // Pass log-probabilities for this example + time. 183 Step(input[t].row(b)); 184 } // for (int t... 185 186 // O(n * log(n)) 187 std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract()); 188 leaves_.Reset(); 189 for (int i = 0; i < branches->size(); ++i) { 190 BeamEntry* entry = (*branches)[i]; 191 beam_scorer_->ExpandStateEnd(&entry->state); 192 entry->newp.total += 193 beam_scorer_->GetStateEndExpansionScore(entry->state); 194 leaves_.push(entry); 195 } 196 197 Status status = 198 TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_); 199 if (!status.ok()) { 200 return status; 201 } 202 203 CHECK_EQ(top_n, beam_log_probabilities.size()); 204 CHECK_EQ(beams.size(), beam_log_probabilities.size()); 205 206 for (int i = 0; i < top_n; ++i) { 207 // Copy output to the correct beam + batch 208 (*output)[i][b].swap(beams[i]); 209 (*scores)(b, i) = -beam_log_probabilities[i]; 210 } 211 } // for (int b... 212 return Status::OK(); 213 } 214 215 template <typename CTCBeamState, typename CTCBeamComparer> 216 template <typename Vector> 217 float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK( 218 const int K, const Vector& input, std::vector<float>* top_k_logits, 219 std::vector<int>* top_k_indices) { 220 // Find Top K choices, complexity nk in worst case. The array input is read 221 // just once. 222 CHECK_EQ(num_classes_, input.size()); 223 top_k_logits->clear(); 224 top_k_indices->clear(); 225 top_k_logits->resize(K, -INFINITY); 226 top_k_indices->resize(K, -1); 227 for (int j = 0; j < num_classes_ - 1; ++j) { 228 const float logit = input(j); 229 if (logit > (*top_k_logits)[K - 1]) { 230 int k = K - 1; 231 while (k > 0 && logit > (*top_k_logits)[k - 1]) { 232 (*top_k_logits)[k] = (*top_k_logits)[k - 1]; 233 (*top_k_indices)[k] = (*top_k_indices)[k - 1]; 234 k--; 235 } 236 (*top_k_logits)[k] = logit; 237 (*top_k_indices)[k] = j; 238 } 239 } 240 // Return max value which is in 0th index or blank character logit 241 return std::max((*top_k_logits)[0], input(num_classes_ - 1)); 242 } 243 244 template <typename CTCBeamState, typename CTCBeamComparer> 245 template <typename Vector> 246 void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( 247 const Vector& raw_input) { 248 std::vector<float> top_k_logits; 249 std::vector<int> top_k_indices; 250 const bool top_k = 251 (label_selection_size_ > 0 && label_selection_size_ < raw_input.size()); 252 // Number of character classes to consider in each step. 253 const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1); 254 // Get max coefficient and remove it from raw_input later. 255 float max_coeff; 256 if (top_k) { 257 max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits, 258 &top_k_indices); 259 } else { 260 max_coeff = raw_input.maxCoeff(); 261 } 262 const float label_selection_input_min = 263 (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_) 264 : -std::numeric_limits<float>::infinity(); 265 266 // Extract the beams sorted in decreasing new probability 267 CHECK_EQ(num_classes_, raw_input.size()); 268 269 std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract()); 270 leaves_.Reset(); 271 272 for (BeamEntry* b : *branches) { 273 // P(.. @ t) becomes the new P(.. @ t-1) 274 b->oldp = b->newp; 275 } 276 277 for (BeamEntry* b : *branches) { 278 if (b->parent != nullptr) { // if not the root 279 if (b->parent->Active()) { 280 // If last two sequence characters are identical: 281 // Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5) 282 // + Pblank(l=ac @ t=5)) 283 // else: 284 // Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5) 285 // + P(l=ab @ t=5)) 286 float previous = (b->label == b->parent->label) ? b->parent->oldp.blank 287 : b->parent->oldp.total; 288 b->newp.label = 289 LogSumExp(b->newp.label, 290 beam_scorer_->GetStateExpansionScore(b->state, previous)); 291 } 292 // Plabel(l=abc @ t=6) *= P(c @ 6) 293 b->newp.label += raw_input(b->label) - max_coeff; 294 } 295 // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6) 296 b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff; 297 // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6) 298 b->newp.total = LogSumExp(b->newp.blank, b->newp.label); 299 300 // Push the entry back to the top paths list. 301 // Note, this will always fill leaves back up in sorted order. 302 leaves_.push(b); 303 } 304 305 // we need to resort branches in descending oldp order. 306 307 // branches is in descending oldp order because it was 308 // originally in descending newp order and we copied newp to oldp. 309 310 // Grow new leaves 311 for (BeamEntry* b : *branches) { 312 // A new leaf (represented by its BeamProbability) is a candidate 313 // iff its total probability is nonzero and either the beam list 314 // isn't full, or the lowest probability entry in the beam has a 315 // lower probability than the leaf. 316 auto is_candidate = [this](const BeamProbability& prob) { 317 return (prob.total > kLogZero && 318 (leaves_.size() < beam_width_ || 319 prob.total > leaves_.peek_bottom()->newp.total)); 320 }; 321 322 if (!is_candidate(b->oldp)) { 323 continue; 324 } 325 326 for (int ind = 0; ind < max_classes; ind++) { 327 const int label = top_k ? top_k_indices[ind] : ind; 328 const float logit = top_k ? top_k_logits[ind] : raw_input(ind); 329 // Perform label selection: if input for this label looks very 330 // unpromising, never evaluate it with a scorer. 331 if (logit < label_selection_input_min) { 332 continue; 333 } 334 BeamEntry& c = b->GetChild(label); 335 if (!c.Active()) { 336 // Pblank(l=abcd @ t=6) = 0 337 c.newp.blank = kLogZero; 338 // If new child label is identical to beam label: 339 // Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6) 340 // Otherwise: 341 // Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6) 342 beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label); 343 float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total; 344 c.newp.label = logit - max_coeff + 345 beam_scorer_->GetStateExpansionScore(c.state, previous); 346 // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6) 347 c.newp.total = c.newp.label; 348 349 if (is_candidate(c.newp)) { 350 // Before adding the new node to the beam, check if the beam 351 // is already at maximum width. 352 if (leaves_.size() == beam_width_) { 353 // Bottom is no longer in the beam search. Reset 354 // its probability; signal it's no longer in the beam search. 355 BeamEntry* bottom = leaves_.peek_bottom(); 356 bottom->newp.Reset(); 357 } 358 leaves_.push(&c); 359 } else { 360 // Deactivate child. 361 c.oldp.Reset(); 362 c.newp.Reset(); 363 } 364 } 365 } 366 } // for (BeamEntry* b... 367 } 368 369 template <typename CTCBeamState, typename CTCBeamComparer> 370 void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() { 371 leaves_.Reset(); 372 373 // This beam root, and all of its children, will be in memory until 374 // the next reset. 375 beam_root_.reset(new BeamRoot(nullptr, -1)); 376 beam_root_->RootEntry()->newp.total = 0.0; // ln(1) 377 beam_root_->RootEntry()->newp.blank = 0.0; // ln(1) 378 379 // Add the root as the initial leaf. 380 leaves_.push(beam_root_->RootEntry()); 381 382 // Call initialize state on the root object. 383 beam_scorer_->InitializeState(&beam_root_->RootEntry()->state); 384 } 385 386 template <typename CTCBeamState, typename CTCBeamComparer> 387 Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths( 388 int n, std::vector<std::vector<int>>* paths, std::vector<float>* log_probs, 389 bool merge_repeated) const { 390 CHECK_NOTNULL(paths)->clear(); 391 CHECK_NOTNULL(log_probs)->clear(); 392 if (n > beam_width_) { 393 return errors::InvalidArgument("requested more paths than the beam width."); 394 } 395 if (n > leaves_.size()) { 396 return errors::InvalidArgument( 397 "Less leaves in the beam search than requested."); 398 } 399 400 gtl::TopN<BeamEntry*, CTCBeamComparer> top_branches(n); 401 402 // O(beam_width_ * log(n)), space complexity is O(n) 403 for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) { 404 top_branches.push(*it); 405 } 406 // O(n * log(n)) 407 std::unique_ptr<std::vector<BeamEntry*>> branches(top_branches.Extract()); 408 409 for (int i = 0; i < n; ++i) { 410 BeamEntry* e((*branches)[i]); 411 paths->push_back(e->LabelSeq(merge_repeated)); 412 log_probs->push_back(e->newp.total); 413 } 414 return Status::OK(); 415 } 416 417 } // namespace ctc 418 } // namespace tensorflow 419 420 #endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ 421