Home | History | Annotate | Download | only in ctc
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
     17 #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
     18 
     19 #include <algorithm>
     20 #include <cmath>
     21 #include <limits>
     22 #include <memory>
     23 #include <vector>
     24 
     25 #include "third_party/eigen3/Eigen/Core"
     26 #include "tensorflow/core/lib/core/errors.h"
     27 #include "tensorflow/core/lib/core/status.h"
     28 #include "tensorflow/core/lib/gtl/top_n.h"
     29 #include "tensorflow/core/platform/logging.h"
     30 #include "tensorflow/core/platform/macros.h"
     31 #include "tensorflow/core/platform/types.h"
     32 #include "tensorflow/core/util/ctc/ctc_beam_entry.h"
     33 #include "tensorflow/core/util/ctc/ctc_beam_scorer.h"
     34 #include "tensorflow/core/util/ctc/ctc_decoder.h"
     35 #include "tensorflow/core/util/ctc/ctc_loss_util.h"
     36 
     37 namespace tensorflow {
     38 namespace ctc {
     39 
     40 template <typename CTCBeamState = ctc_beam_search::EmptyBeamState,
     41           typename CTCBeamComparer =
     42               ctc_beam_search::BeamComparer<CTCBeamState>>
     43 class CTCBeamSearchDecoder : public CTCDecoder {
     44   // Beam Search
     45   //
     46   // Example (GravesTh Fig. 7.5):
     47   //         a    -
     48   //  P = [ 0.3  0.7 ]  t = 0
     49   //      [ 0.4  0.6 ]  t = 1
     50   //
     51   // Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42
     52   //      P(l = a) = P(a-) + P(aa) + P(-a) = 0.3*0.4 + ... = 0.58
     53   //
     54   // In this case, Best Path decoding is suboptimal.
     55   //
     56   // For Beam Search, we use the following main recurrence relations:
     57   //
     58   // Relation 1:
     59   // ---------------------------------------------------------- Eq. 1
     60   //      P(l=abcd @ t=7) = P(l=abc  @ t=6) * P(d @ 7)
     61   //                      + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7))
     62   // where P(l=? @ t=7), ? = a, ab, abc, abcd are all stored and
     63   // updated recursively in the beam entry.
     64   //
     65   // Relation 2:
     66   // ---------------------------------------------------------- Eq. 2
     67   //      P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3)
     68   // for ? in a, b, d, ..., (not including c or the blank index),
     69   // and the recurrence starts from the beam entry for P(l=abc @ t=2).
     70   //
     71   // For this case, the length of the new sequence equals t+1 (t
     72   // starts at 0).  This special case can be calculated as:
     73   //   P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
     74   // but we calculate it recursively for speed purposes.
     75   typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
     76   typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
     77   typedef ctc_beam_search::BeamProbability BeamProbability;
     78 
     79  public:
     80   typedef BaseBeamScorer<CTCBeamState> DefaultBeamScorer;
     81 
     82   // The beam search decoder is constructed specifying the beam_width (number of
     83   // candidates to keep at each decoding timestep) and a beam scorer (used for
     84   // custom scoring, for example enabling the use of a language model).
     85   // The ownership of the scorer remains with the caller. The default
     86   // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the
     87   // standard beam search.
     88   CTCBeamSearchDecoder(int num_classes, int beam_width,
     89                        BaseBeamScorer<CTCBeamState>* scorer, int batch_size = 1,
     90                        bool merge_repeated = false)
     91       : CTCDecoder(num_classes, batch_size, merge_repeated),
     92         beam_width_(beam_width),
     93         leaves_(beam_width),
     94         beam_scorer_(CHECK_NOTNULL(scorer)) {
     95     Reset();
     96   }
     97 
     98   ~CTCBeamSearchDecoder() override {}
     99 
    100   // Run the hibernating beam search algorithm on the given input.
    101   Status Decode(const CTCDecoder::SequenceLength& seq_len,
    102                 const std::vector<CTCDecoder::Input>& input,
    103                 std::vector<CTCDecoder::Output>* output,
    104                 CTCDecoder::ScoreOutput* scores) override;
    105 
    106   // Calculate the next step of the beam search and update the internal state.
    107   template <typename Vector>
    108   void Step(const Vector& log_input_t);
    109 
    110   template <typename Vector>
    111   float GetTopK(const int K, const Vector& input,
    112                 std::vector<float>* top_k_logits,
    113                 std::vector<int>* top_k_indices);
    114 
    115   // Retrieve the beam scorer instance used during decoding.
    116   BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }
    117 
    118   // Set label selection parameters for faster decoding.
    119   // See comments for label_selection_size_ and label_selection_margin_.
    120   void SetLabelSelectionParameters(int label_selection_size,
    121                                    float label_selection_margin) {
    122     label_selection_size_ = label_selection_size;
    123     label_selection_margin_ = label_selection_margin;
    124   }
    125 
    126   // Reset the beam search
    127   void Reset();
    128 
    129   // Extract the top n paths at current time step
    130   Status TopPaths(int n, std::vector<std::vector<int>>* paths,
    131                   std::vector<float>* log_probs, bool merge_repeated) const;
    132 
    133  private:
    134   int beam_width_;
    135 
    136   // Label selection is designed to avoid possibly very expensive scorer calls,
    137   // by pruning the hypotheses based on the input alone.
    138   // Label selection size controls how many items in each beam are passed
    139   // through to the beam scorer. Only items with top N input scores are
    140   // considered.
    141   // Label selection margin controls the difference between minimal input score
    142   // (versus the best scoring label) for an item to be passed to the beam
    143   // scorer. This margin is expressed in terms of log-probability.
    144   // Default is to do no label selection.
    145   // For more detail: https://research.google.com/pubs/pub44823.html
    146   int label_selection_size_ = 0;       // zero means unlimited
    147   float label_selection_margin_ = -1;  // -1 means unlimited.
    148 
    149   gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
    150   std::unique_ptr<BeamRoot> beam_root_;
    151   BaseBeamScorer<CTCBeamState>* beam_scorer_;
    152 
    153   TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
    154 };
    155 
    156 template <typename CTCBeamState, typename CTCBeamComparer>
    157 Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
    158     const CTCDecoder::SequenceLength& seq_len,
    159     const std::vector<CTCDecoder::Input>& input,
    160     std::vector<CTCDecoder::Output>* output, ScoreOutput* scores) {
    161   // Storage for top paths.
    162   std::vector<std::vector<int>> beams;
    163   std::vector<float> beam_log_probabilities;
    164   int top_n = output->size();
    165   if (std::any_of(output->begin(), output->end(),
    166                   [this](const CTCDecoder::Output& output) -> bool {
    167                     return output.size() < this->batch_size_;
    168                   })) {
    169     return errors::InvalidArgument(
    170         "output needs to be of size at least (top_n, batch_size).");
    171   }
    172   if (scores->rows() < batch_size_ || scores->cols() < top_n) {
    173     return errors::InvalidArgument(
    174         "scores needs to be of size at least (batch_size, top_n).");
    175   }
    176 
    177   for (int b = 0; b < batch_size_; ++b) {
    178     int seq_len_b = seq_len[b];
    179     Reset();
    180 
    181     for (int t = 0; t < seq_len_b; ++t) {
    182       // Pass log-probabilities for this example + time.
    183       Step(input[t].row(b));
    184     }  // for (int t...
    185 
    186     // O(n * log(n))
    187     std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
    188     leaves_.Reset();
    189     for (int i = 0; i < branches->size(); ++i) {
    190       BeamEntry* entry = (*branches)[i];
    191       beam_scorer_->ExpandStateEnd(&entry->state);
    192       entry->newp.total +=
    193           beam_scorer_->GetStateEndExpansionScore(entry->state);
    194       leaves_.push(entry);
    195     }
    196 
    197     Status status =
    198         TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_);
    199     if (!status.ok()) {
    200       return status;
    201     }
    202 
    203     CHECK_EQ(top_n, beam_log_probabilities.size());
    204     CHECK_EQ(beams.size(), beam_log_probabilities.size());
    205 
    206     for (int i = 0; i < top_n; ++i) {
    207       // Copy output to the correct beam + batch
    208       (*output)[i][b].swap(beams[i]);
    209       (*scores)(b, i) = -beam_log_probabilities[i];
    210     }
    211   }  // for (int b...
    212   return Status::OK();
    213 }
    214 
    215 template <typename CTCBeamState, typename CTCBeamComparer>
    216 template <typename Vector>
    217 float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK(
    218     const int K, const Vector& input, std::vector<float>* top_k_logits,
    219     std::vector<int>* top_k_indices) {
    220   // Find Top K choices, complexity nk in worst case. The array input is read
    221   // just once.
    222   CHECK_EQ(num_classes_, input.size());
    223   top_k_logits->clear();
    224   top_k_indices->clear();
    225   top_k_logits->resize(K, -INFINITY);
    226   top_k_indices->resize(K, -1);
    227   for (int j = 0; j < num_classes_ - 1; ++j) {
    228     const float logit = input(j);
    229     if (logit > (*top_k_logits)[K - 1]) {
    230       int k = K - 1;
    231       while (k > 0 && logit > (*top_k_logits)[k - 1]) {
    232         (*top_k_logits)[k] = (*top_k_logits)[k - 1];
    233         (*top_k_indices)[k] = (*top_k_indices)[k - 1];
    234         k--;
    235       }
    236       (*top_k_logits)[k] = logit;
    237       (*top_k_indices)[k] = j;
    238     }
    239   }
    240   // Return max value which is in 0th index or blank character logit
    241   return std::max((*top_k_logits)[0], input(num_classes_ - 1));
    242 }
    243 
    244 template <typename CTCBeamState, typename CTCBeamComparer>
    245 template <typename Vector>
    246 void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
    247     const Vector& raw_input) {
    248   std::vector<float> top_k_logits;
    249   std::vector<int> top_k_indices;
    250   const bool top_k =
    251       (label_selection_size_ > 0 && label_selection_size_ < raw_input.size());
    252   // Number of character classes to consider in each step.
    253   const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1);
    254   // Get max coefficient and remove it from raw_input later.
    255   float max_coeff;
    256   if (top_k) {
    257     max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits,
    258                         &top_k_indices);
    259   } else {
    260     max_coeff = raw_input.maxCoeff();
    261   }
    262   const float label_selection_input_min =
    263       (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
    264                                      : -std::numeric_limits<float>::infinity();
    265 
    266   // Extract the beams sorted in decreasing new probability
    267   CHECK_EQ(num_classes_, raw_input.size());
    268 
    269   std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
    270   leaves_.Reset();
    271 
    272   for (BeamEntry* b : *branches) {
    273     // P(.. @ t) becomes the new P(.. @ t-1)
    274     b->oldp = b->newp;
    275   }
    276 
    277   for (BeamEntry* b : *branches) {
    278     if (b->parent != nullptr) {  // if not the root
    279       if (b->parent->Active()) {
    280         // If last two sequence characters are identical:
    281         //   Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5)
    282         //                          + Pblank(l=ac @ t=5))
    283         // else:
    284         //   Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5)
    285         //                          + P(l=ab @ t=5))
    286         float previous = (b->label == b->parent->label) ? b->parent->oldp.blank
    287                                                         : b->parent->oldp.total;
    288         b->newp.label =
    289             LogSumExp(b->newp.label,
    290                       beam_scorer_->GetStateExpansionScore(b->state, previous));
    291       }
    292       // Plabel(l=abc @ t=6) *= P(c @ 6)
    293       b->newp.label += raw_input(b->label) - max_coeff;
    294     }
    295     // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
    296     b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff;
    297     // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
    298     b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
    299 
    300     // Push the entry back to the top paths list.
    301     // Note, this will always fill leaves back up in sorted order.
    302     leaves_.push(b);
    303   }
    304 
    305   // we need to resort branches in descending oldp order.
    306 
    307   // branches is in descending oldp order because it was
    308   // originally in descending newp order and we copied newp to oldp.
    309 
    310   // Grow new leaves
    311   for (BeamEntry* b : *branches) {
    312     // A new leaf (represented by its BeamProbability) is a candidate
    313     // iff its total probability is nonzero and either the beam list
    314     // isn't full, or the lowest probability entry in the beam has a
    315     // lower probability than the leaf.
    316     auto is_candidate = [this](const BeamProbability& prob) {
    317       return (prob.total > kLogZero &&
    318               (leaves_.size() < beam_width_ ||
    319                prob.total > leaves_.peek_bottom()->newp.total));
    320     };
    321 
    322     if (!is_candidate(b->oldp)) {
    323       continue;
    324     }
    325 
    326     for (int ind = 0; ind < max_classes; ind++) {
    327       const int label = top_k ? top_k_indices[ind] : ind;
    328       const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
    329       // Perform label selection: if input for this label looks very
    330       // unpromising, never evaluate it with a scorer.
    331       if (logit < label_selection_input_min) {
    332         continue;
    333       }
    334       BeamEntry& c = b->GetChild(label);
    335       if (!c.Active()) {
    336         //   Pblank(l=abcd @ t=6) = 0
    337         c.newp.blank = kLogZero;
    338         // If new child label is identical to beam label:
    339         //   Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
    340         // Otherwise:
    341         //   Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
    342         beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
    343         float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
    344         c.newp.label = logit - max_coeff +
    345                        beam_scorer_->GetStateExpansionScore(c.state, previous);
    346         // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
    347         c.newp.total = c.newp.label;
    348 
    349         if (is_candidate(c.newp)) {
    350           // Before adding the new node to the beam, check if the beam
    351           // is already at maximum width.
    352           if (leaves_.size() == beam_width_) {
    353             // Bottom is no longer in the beam search.  Reset
    354             // its probability; signal it's no longer in the beam search.
    355             BeamEntry* bottom = leaves_.peek_bottom();
    356             bottom->newp.Reset();
    357           }
    358           leaves_.push(&c);
    359         } else {
    360           // Deactivate child.
    361           c.oldp.Reset();
    362           c.newp.Reset();
    363         }
    364       }
    365     }
    366   }  // for (BeamEntry* b...
    367 }
    368 
    369 template <typename CTCBeamState, typename CTCBeamComparer>
    370 void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
    371   leaves_.Reset();
    372 
    373   // This beam root, and all of its children, will be in memory until
    374   // the next reset.
    375   beam_root_.reset(new BeamRoot(nullptr, -1));
    376   beam_root_->RootEntry()->newp.total = 0.0;  // ln(1)
    377   beam_root_->RootEntry()->newp.blank = 0.0;  // ln(1)
    378 
    379   // Add the root as the initial leaf.
    380   leaves_.push(beam_root_->RootEntry());
    381 
    382   // Call initialize state on the root object.
    383   beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
    384 }
    385 
    386 template <typename CTCBeamState, typename CTCBeamComparer>
    387 Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths(
    388     int n, std::vector<std::vector<int>>* paths, std::vector<float>* log_probs,
    389     bool merge_repeated) const {
    390   CHECK_NOTNULL(paths)->clear();
    391   CHECK_NOTNULL(log_probs)->clear();
    392   if (n > beam_width_) {
    393     return errors::InvalidArgument("requested more paths than the beam width.");
    394   }
    395   if (n > leaves_.size()) {
    396     return errors::InvalidArgument(
    397         "Less leaves in the beam search than requested.");
    398   }
    399 
    400   gtl::TopN<BeamEntry*, CTCBeamComparer> top_branches(n);
    401 
    402   // O(beam_width_ * log(n)), space complexity is O(n)
    403   for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) {
    404     top_branches.push(*it);
    405   }
    406   // O(n * log(n))
    407   std::unique_ptr<std::vector<BeamEntry*>> branches(top_branches.Extract());
    408 
    409   for (int i = 0; i < n; ++i) {
    410     BeamEntry* e((*branches)[i]);
    411     paths->push_back(e->LabelSeq(merge_repeated));
    412     log_probs->push_back(e->newp.total);
    413   }
    414   return Status::OK();
    415 }
    416 
    417 }  // namespace ctc
    418 }  // namespace tensorflow
    419 
    420 #endif  // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
    421