/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/ctc_ops.cc.

#define EIGEN_USE_THREADS

#include <algorithm>
#include <limits>
#include <vector>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

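// Returns the largest value in row `r` of `m` and writes the column index of
// that maximum (the argmax over classes) into `*c`.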
inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,
                    int* c) {
  *c = 0;
  CHECK_LT(0, m.dimension(1));
  float p = m(r, 0);
  for (int i = 1; i < m.dimension(1); ++i) {
    if (m(r, i) > p) {
      p = m(r, i);
      *c = i;
    }
  }
  return p;
}

class CTCDecodeHelper {
 public:
  CTCDecodeHelper() : top_paths_(1) {}

  inline int GetTopPaths() const { return top_paths_; }
  void SetTopPaths(int tp) { top_paths_ = tp; }

  Status ValidateInputsGenerateOutputs(
      OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
      Tensor** log_prob, OpOutputList* decoded_indices,
      OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
    Status status = ctx->input("inputs", inputs);
    if (!status.ok()) return status;
    status = ctx->input("sequence_length", seq_len);
    if (!status.ok()) return status;

    const TensorShape& inputs_shape = (*inputs)->shape();

    if (inputs_shape.dims() != 3) {
      return errors::InvalidArgument("inputs is not a 3-Tensor");
    }

    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);

    if (max_time == 0) {
      return errors::InvalidArgument("max_time is 0");
    }
    if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
      return errors::InvalidArgument("sequence_length is not a vector");
    }

    if (batch_size != (*seq_len)->dim_size(0)) {
      return errors::FailedPrecondition(
          "len(sequence_length) != batch_size.  ",
          "len(sequence_length):  ", (*seq_len)->dim_size(0),
          " batch_size: ", batch_size);
    }

    auto seq_len_t = (*seq_len)->vec<int32>();

    for (int b = 0; b < batch_size; ++b) {
      if (seq_len_t(b) > max_time) {
        return errors::FailedPrecondition("sequence_length(", b,
                                          ") > max_time: ", seq_len_t(b),
                                          " vs. ", max_time);
      }
    }

    Status s = ctx->allocate_output(
        "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
    if (!s.ok()) return s;

    s = ctx->output_list("decoded_indices", decoded_indices);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_values", decoded_values);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_shape", decoded_shape);
    if (!s.ok()) return s;

    return Status::OK();
  }

  // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
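  //
  // Illustrative sketch (editorial, not from the original source): with
  // top_paths_ == 1 and sequences == {{{1, 2}}, {{3}}} (batch_size == 2), the
  // SparseTensor components emitted for path 0 would be
  //   indices = [[0, 0], [0, 1], [1, 0]]   (batch, position within decoding)
  //   values  = [1, 2, 3]
  //   shape   = [2, 2]                     (batch_size, longest decoding)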
  Status StoreAllDecodedSequences(
      const std::vector<std::vector<std::vector<int> > >& sequences,
      OpOutputList* decoded_indices, OpOutputList* decoded_values,
      OpOutputList* decoded_shape) const {
    // Calculate the total number of entries for each path
    const int64 batch_size = sequences.size();
    std::vector<int64> num_entries(top_paths_, 0);

    // Calculate num_entries per path
    for (const auto& batch_s : sequences) {
      CHECK_EQ(batch_s.size(), top_paths_);
      for (int p = 0; p < top_paths_; ++p) {
        num_entries[p] += batch_s[p].size();
      }
    }

    for (int p = 0; p < top_paths_; ++p) {
      Tensor* p_indices = nullptr;
      Tensor* p_values = nullptr;
      Tensor* p_shape = nullptr;

      const int64 p_num = num_entries[p];

      Status s =
          decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
      if (!s.ok()) return s;
      s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
      if (!s.ok()) return s;
      s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
      if (!s.ok()) return s;

      auto indices_t = p_indices->matrix<int64>();
      auto values_t = p_values->vec<int64>();
      auto shape_t = p_shape->vec<int64>();

      int64 max_decoded = 0;
      int64 offset = 0;

      for (int64 b = 0; b < batch_size; ++b) {
        auto& p_batch = sequences[b][p];
        int64 num_decoded = p_batch.size();
        max_decoded = std::max(max_decoded, num_decoded);
        std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
        for (int64 t = 0; t < num_decoded; ++t, ++offset) {
          indices_t(offset, 0) = b;
          indices_t(offset, 1) = t;
        }
      }

      shape_t(0) = batch_size;
      shape_t(1) = max_decoded;
    }
    return Status::OK();
  }

 private:
  int top_paths_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};

class CTCGreedyDecoderOp : public OpKernel {
 public:
  explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* seq_len;
    Tensor* log_prob = nullptr;
    OpOutputList decoded_indices;
    OpOutputList decoded_values;
    OpOutputList decoded_shape;
    OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
                            ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
                            &decoded_values, &decoded_shape));

    const TensorShape& inputs_shape = inputs->shape();

    std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);
    const int64 num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int num_classes = static_cast<int>(num_classes_raw);

    auto inputs_t = inputs->tensor<float, 3>();

    for (int64 t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
    }
    auto seq_len_t = seq_len->vec<int32>();
    auto log_prob_t = log_prob->matrix<float>();

    log_prob_t.setZero();

    // Assumption: the blank index is num_classes - 1
    int blank_index = num_classes - 1;

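    // Illustrative sketch (editorial): with num_classes == 3 (so blank == 2)
    // and per-frame argmax classes [0, 0, 2, 1], the loop below decodes
    // [0, 1] when merge_repeated_ is true and [0, 0, 1] when it is false;
    // blank frames are always dropped.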
    // Perform best path decoding
    std::vector<std::vector<std::vector<int> > > sequences(batch_size);
    for (int b = 0; b < batch_size; ++b) {
      sequences[b].resize(1);
      auto& sequence = sequences[b][0];
      int prev_indices = -1;
      for (int t = 0; t < seq_len_t(b); ++t) {
        int max_class_indices;
        log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices);
        if (max_class_indices != blank_index &&
            !(merge_repeated_ && max_class_indices == prev_indices)) {
          sequence.push_back(max_class_indices);
        }
        prev_indices = max_class_indices;
      }
    }

    OP_REQUIRES_OK(
        ctx, decode_helper_.StoreAllDecodedSequences(
                 sequences, &decoded_indices, &decoded_values, &decoded_shape));
  }

 private:
  CTCDecodeHelper decode_helper_;
  bool merge_repeated_;

  TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
};

REGISTER_KERNEL_BUILDER(Name("CTCGreedyDecoder").Device(DEVICE_CPU),
                        CTCGreedyDecoderOp);

// CTC beam search
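//
// The op below validates its inputs, feeds each time frame of a batch element
// into the beam search via Step(), extracts the highest-scoring paths with
// TopPaths(), and writes them out as SparseTensors through
// StoreAllDecodedSequences().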
class CTCBeamSearchDecoderOp : public OpKernel {
 public:
  explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
    int top_paths;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
    decode_helper_.SetTopPaths(top_paths);
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* seq_len;
    Tensor* log_prob = nullptr;
    OpOutputList decoded_indices;
    OpOutputList decoded_values;
    OpOutputList decoded_shape;
    OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
                            ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
                            &decoded_values, &decoded_shape));

    auto inputs_t = inputs->tensor<float, 3>();
    auto seq_len_t = seq_len->vec<int32>();
    auto log_prob_t = log_prob->matrix<float>();

    const TensorShape& inputs_shape = inputs->shape();

    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);
    const int64 num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int num_classes = static_cast<int>(num_classes_raw);

    log_prob_t.setZero();

    std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;

    for (int64 t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
    }

    ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,
                                            &beam_scorer_, 1 /* batch_size */,
                                            merge_repeated_);
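    // Note (editorial): the decoder is constructed with a batch size of 1
    // because each batch element is decoded independently in the loop below;
    // beam_search.Reset() clears its state between elements.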
    Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));
    auto input_chip_t = input_chip.flat<float>();

    std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
    std::vector<float> log_probs;

    // Assumption: the blank index is num_classes - 1
    for (int b = 0; b < batch_size; ++b) {
      auto& best_paths_b = best_paths[b];
      best_paths_b.resize(decode_helper_.GetTopPaths());
      for (int t = 0; t < seq_len_t(b); ++t) {
        input_chip_t = input_list_t[t].chip(b, 0);
        auto input_bi =
            Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
        beam_search.Step(input_bi);
      }
      OP_REQUIRES_OK(
          ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,
                                    &log_probs, merge_repeated_));

      beam_search.Reset();

      for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
        log_prob_t(b, bp) = log_probs[bp];
      }
    }

    OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
                            best_paths, &decoded_indices, &decoded_values,
                            &decoded_shape));
  }

 private:
  CTCDecodeHelper decode_helper_;
  ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;
  bool merge_repeated_;
  int beam_width_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp);
};

REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoder").Device(DEVICE_CPU),
                        CTCBeamSearchDecoderOp);

}  // end namespace tensorflow