Home | History | Annotate | Download | only in ctc
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
     17 #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
     18 
     19 #include <algorithm>
     20 #include <memory>
     21 #include <vector>
     22 
     23 #include "third_party/eigen3/Eigen/Core"
     24 #include "tensorflow/core/lib/gtl/flatmap.h"
     25 #include "tensorflow/core/platform/logging.h"
     26 #include "tensorflow/core/platform/macros.h"
     27 #include "tensorflow/core/platform/types.h"
     28 #include "tensorflow/core/util/ctc/ctc_loss_util.h"
     29 
     30 namespace tensorflow {
     31 namespace ctc {
     32 
     33 // The ctc_beam_search namespace holds several classes meant to be accessed only
     34 // in case of extending the CTCBeamSearch decoder to allow custom scoring
     35 // functions.
     36 //
     37 // BeamEntry is exposed through template arguments BeamScorer and BeamComparer
     38 // of CTCBeamSearch (ctc_beam_search.h).
     39 namespace ctc_beam_search {
     40 
     41 struct EmptyBeamState {};
     42 
     43 struct BeamProbability {
     44   BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {}
     45   void Reset() {
     46     total = kLogZero;
     47     blank = kLogZero;
     48     label = kLogZero;
     49   }
     50   float total;
     51   float blank;
     52   float label;
     53 };
     54 
     55 template <class CTCBeamState>
     56 class BeamRoot;
     57 
     58 template <class CTCBeamState = EmptyBeamState>
     59 struct BeamEntry {
     60   // BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
     61   friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry(
     62       BeamEntry<CTCBeamState>* p, int l);
     63   inline bool Active() const { return newp.total != kLogZero; }
     64   // Return the child at the given index, or construct a new one in-place if
     65   // none was found.
     66   BeamEntry& GetChild(int ind) {
     67     auto entry = children.emplace(ind, nullptr);
     68     auto& child_entry = entry.first->second;
     69     // If this is a new child, populate the BeamEntry<CTCBeamState>*.
     70     if (entry.second) {
     71       child_entry = beam_root->AddEntry(this, ind);
     72     }
     73     return *child_entry;
     74   }
     75   std::vector<int> LabelSeq(bool merge_repeated) const {
     76     std::vector<int> labels;
     77     int prev_label = -1;
     78     const BeamEntry* c = this;
     79     while (c->parent != nullptr) {  // Checking c->parent to skip root leaf.
     80       if (!merge_repeated || c->label != prev_label) {
     81         labels.push_back(c->label);
     82       }
     83       prev_label = c->label;
     84       c = c->parent;
     85     }
     86     std::reverse(labels.begin(), labels.end());
     87     return labels;
     88   }
     89 
     90   BeamEntry<CTCBeamState>* parent;
     91   int label;
     92   // All instances of child BeamEntry are owned by *beam_root.
     93   gtl::FlatMap<int, BeamEntry<CTCBeamState>*> children;
     94   BeamProbability oldp;
     95   BeamProbability newp;
     96   CTCBeamState state;
     97 
     98  private:
     99   // Constructor giving parent, label, and the beam_root.
    100   // The object pointed to by p cannot be copied and should not be moved,
    101   // otherwise parent will become invalid.
    102   // This private constructor is only called through the factory method
    103   // BeamRoot<CTCBeamState>::AddEntry().
    104   BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root)
    105       : parent(p), label(l), beam_root(beam_root) {}
    106   BeamRoot<CTCBeamState>* beam_root;
    107   TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry);
    108 };
    109 
    110 // This class owns all instances of BeamEntry.  This is used to avoid recursive
    111 // destructor call during destruction.
    112 template <class CTCBeamState = EmptyBeamState>
    113 class BeamRoot {
    114  public:
    115   BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); }
    116   BeamRoot(const BeamRoot&) = delete;
    117   BeamRoot& operator=(const BeamRoot&) = delete;
    118 
    119   BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) {
    120     auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this);
    121     beam_entries_.emplace_back(new_entry);
    122     return new_entry;
    123   }
    124   BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; }
    125 
    126  private:
    127   BeamEntry<CTCBeamState>* root_entry_ = nullptr;
    128   std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_;
    129 };
    130 
    131 // BeamComparer is the default beam comparer provided in CTCBeamSearch.
    132 template <class CTCBeamState = EmptyBeamState>
    133 class BeamComparer {
    134  public:
    135   virtual ~BeamComparer() {}
    136   virtual bool inline operator()(const BeamEntry<CTCBeamState>* a,
    137                                  const BeamEntry<CTCBeamState>* b) const {
    138     return a->newp.total > b->newp.total;
    139   }
    140 };
    141 
    142 }  // namespace ctc_beam_search
    143 
    144 }  // namespace ctc
    145 }  // namespace tensorflow
    146 
    147 #endif  // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
    148