1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/ctc_ops.cc. 17 18 #define EIGEN_USE_THREADS 19 20 #include <limits> 21 22 #include "tensorflow/core/framework/op.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/types.h" 25 #include "tensorflow/core/lib/core/status.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/platform/macros.h" 28 #include "tensorflow/core/util/ctc/ctc_beam_search.h" 29 #include "tensorflow/core/util/sparse/sparse_tensor.h" 30 31 namespace tensorflow { 32 33 typedef Eigen::ThreadPoolDevice CPUDevice; 34 35 inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r, 36 int* c) { 37 *c = 0; 38 CHECK_LT(0, m.dimension(1)); 39 float p = m(r, 0); 40 for (int i = 1; i < m.dimension(1); ++i) { 41 if (m(r, i) > p) { 42 p = m(r, i); 43 *c = i; 44 } 45 } 46 return p; 47 } 48 49 class CTCDecodeHelper { 50 public: 51 CTCDecodeHelper() : top_paths_(1) {} 52 53 inline int GetTopPaths() const { return top_paths_; } 54 void SetTopPaths(int tp) { top_paths_ = tp; } 55 56 Status ValidateInputsGenerateOutputs( 57 OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len, 58 Tensor** log_prob, OpOutputList* decoded_indices, 59 OpOutputList* decoded_values, OpOutputList* decoded_shape) const { 60 Status status = ctx->input("inputs", inputs); 61 if (!status.ok()) return status; 62 status = ctx->input("sequence_length", seq_len); 63 if (!status.ok()) return status; 64 65 const TensorShape& inputs_shape = (*inputs)->shape(); 66 67 if (inputs_shape.dims() != 3) { 68 return errors::InvalidArgument("inputs is not a 3-Tensor"); 69 } 70 71 const int64 max_time = inputs_shape.dim_size(0); 72 const int64 batch_size = inputs_shape.dim_size(1); 73 74 if (max_time == 0) { 75 return errors::InvalidArgument("max_time is 0"); 76 } 77 if (!TensorShapeUtils::IsVector((*seq_len)->shape())) { 78 return errors::InvalidArgument("sequence_length is not a vector"); 79 } 80 81 if (!(batch_size == (*seq_len)->dim_size(0))) { 82 return errors::FailedPrecondition( 83 "len(sequence_length) != batch_size. ", 84 "len(sequence_length): ", (*seq_len)->dim_size(0), 85 " batch_size: ", batch_size); 86 } 87 88 auto seq_len_t = (*seq_len)->vec<int32>(); 89 90 for (int b = 0; b < batch_size; ++b) { 91 if (!(seq_len_t(b) <= max_time)) { 92 return errors::FailedPrecondition("sequence_length(", b, 93 ") <= ", max_time); 94 } 95 } 96 97 Status s = ctx->allocate_output( 98 "log_probability", TensorShape({batch_size, top_paths_}), log_prob); 99 if (!s.ok()) return s; 100 101 s = ctx->output_list("decoded_indices", decoded_indices); 102 if (!s.ok()) return s; 103 s = ctx->output_list("decoded_values", decoded_values); 104 if (!s.ok()) return s; 105 s = ctx->output_list("decoded_shape", decoded_shape); 106 if (!s.ok()) return s; 107 108 return Status::OK(); 109 } 110 111 // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b". 112 Status StoreAllDecodedSequences( 113 const std::vector<std::vector<std::vector<int> > >& sequences, 114 OpOutputList* decoded_indices, OpOutputList* decoded_values, 115 OpOutputList* decoded_shape) const { 116 // Calculate the total number of entries for each path 117 const int64 batch_size = sequences.size(); 118 std::vector<int64> num_entries(top_paths_, 0); 119 120 // Calculate num_entries per path 121 for (const auto& batch_s : sequences) { 122 CHECK_EQ(batch_s.size(), top_paths_); 123 for (int p = 0; p < top_paths_; ++p) { 124 num_entries[p] += batch_s[p].size(); 125 } 126 } 127 128 for (int p = 0; p < top_paths_; ++p) { 129 Tensor* p_indices = nullptr; 130 Tensor* p_values = nullptr; 131 Tensor* p_shape = nullptr; 132 133 const int64 p_num = num_entries[p]; 134 135 Status s = 136 decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices); 137 if (!s.ok()) return s; 138 s = decoded_values->allocate(p, TensorShape({p_num}), &p_values); 139 if (!s.ok()) return s; 140 s = decoded_shape->allocate(p, TensorShape({2}), &p_shape); 141 if (!s.ok()) return s; 142 143 auto indices_t = p_indices->matrix<int64>(); 144 auto values_t = p_values->vec<int64>(); 145 auto shape_t = p_shape->vec<int64>(); 146 147 int64 max_decoded = 0; 148 int64 offset = 0; 149 150 for (int64 b = 0; b < batch_size; ++b) { 151 auto& p_batch = sequences[b][p]; 152 int64 num_decoded = p_batch.size(); 153 max_decoded = std::max(max_decoded, num_decoded); 154 std::copy_n(p_batch.begin(), num_decoded, &values_t(offset)); 155 for (int64 t = 0; t < num_decoded; ++t, ++offset) { 156 indices_t(offset, 0) = b; 157 indices_t(offset, 1) = t; 158 } 159 } 160 161 shape_t(0) = batch_size; 162 shape_t(1) = max_decoded; 163 } 164 return Status::OK(); 165 } 166 167 private: 168 int top_paths_; 169 TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper); 170 }; 171 172 class CTCGreedyDecoderOp : public OpKernel { 173 public: 174 explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 175 OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_)); 176 } 177 178 void Compute(OpKernelContext* ctx) override { 179 const Tensor* inputs; 180 const Tensor* seq_len; 181 Tensor* log_prob = nullptr; 182 OpOutputList decoded_indices; 183 OpOutputList decoded_values; 184 OpOutputList decoded_shape; 185 OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs( 186 ctx, &inputs, &seq_len, &log_prob, &decoded_indices, 187 &decoded_values, &decoded_shape)); 188 189 const TensorShape& inputs_shape = inputs->shape(); 190 191 std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t; 192 const int64 max_time = inputs_shape.dim_size(0); 193 const int64 batch_size = inputs_shape.dim_size(1); 194 const int64 num_classes_raw = inputs_shape.dim_size(2); 195 OP_REQUIRES( 196 ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()), 197 errors::InvalidArgument("num_classes cannot exceed max int")); 198 const int num_classes = static_cast<const int>(num_classes_raw); 199 200 auto inputs_t = inputs->tensor<float, 3>(); 201 202 for (std::size_t t = 0; t < max_time; ++t) { 203 input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes, 204 batch_size, num_classes); 205 } 206 auto seq_len_t = seq_len->vec<int32>(); 207 auto log_prob_t = log_prob->matrix<float>(); 208 209 log_prob_t.setZero(); 210 211 // Assumption: the blank index is num_classes - 1 212 int blank_index = num_classes - 1; 213 214 // Perform best path decoding 215 std::vector<std::vector<std::vector<int> > > sequences(batch_size); 216 for (int b = 0; b < batch_size; ++b) { 217 sequences[b].resize(1); 218 auto& sequence = sequences[b][0]; 219 int prev_indices = -1; 220 for (int t = 0; t < seq_len_t(b); ++t) { 221 int max_class_indices; 222 log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices); 223 if (max_class_indices != blank_index && 224 !(merge_repeated_ && max_class_indices == prev_indices)) { 225 sequence.push_back(max_class_indices); 226 } 227 prev_indices = max_class_indices; 228 } 229 } 230 231 OP_REQUIRES_OK( 232 ctx, decode_helper_.StoreAllDecodedSequences( 233 sequences, &decoded_indices, &decoded_values, &decoded_shape)); 234 } 235 236 private: 237 CTCDecodeHelper decode_helper_; 238 bool merge_repeated_; 239 240 TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp); 241 }; 242 243 REGISTER_KERNEL_BUILDER(Name("CTCGreedyDecoder").Device(DEVICE_CPU), 244 CTCGreedyDecoderOp); 245 246 // CTC beam search 247 class CTCBeamSearchDecoderOp : public OpKernel { 248 public: 249 explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 250 OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_)); 251 OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_)); 252 int top_paths; 253 OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths)); 254 decode_helper_.SetTopPaths(top_paths); 255 } 256 257 void Compute(OpKernelContext* ctx) override { 258 const Tensor* inputs; 259 const Tensor* seq_len; 260 Tensor* log_prob = nullptr; 261 OpOutputList decoded_indices; 262 OpOutputList decoded_values; 263 OpOutputList decoded_shape; 264 OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs( 265 ctx, &inputs, &seq_len, &log_prob, &decoded_indices, 266 &decoded_values, &decoded_shape)); 267 268 auto inputs_t = inputs->tensor<float, 3>(); 269 auto seq_len_t = seq_len->vec<int32>(); 270 auto log_prob_t = log_prob->matrix<float>(); 271 272 const TensorShape& inputs_shape = inputs->shape(); 273 274 const int64 max_time = inputs_shape.dim_size(0); 275 const int64 batch_size = inputs_shape.dim_size(1); 276 const int64 num_classes_raw = inputs_shape.dim_size(2); 277 OP_REQUIRES( 278 ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()), 279 errors::InvalidArgument("num_classes cannot exceed max int")); 280 const int num_classes = static_cast<const int>(num_classes_raw); 281 282 log_prob_t.setZero(); 283 284 std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t; 285 286 for (std::size_t t = 0; t < max_time; ++t) { 287 input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes, 288 batch_size, num_classes); 289 } 290 291 ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_, 292 &beam_scorer_, 1 /* batch_size */, 293 merge_repeated_); 294 Tensor input_chip(DT_FLOAT, TensorShape({num_classes})); 295 auto input_chip_t = input_chip.flat<float>(); 296 297 std::vector<std::vector<std::vector<int> > > best_paths(batch_size); 298 std::vector<float> log_probs; 299 300 // Assumption: the blank index is num_classes - 1 301 for (int b = 0; b < batch_size; ++b) { 302 auto& best_paths_b = best_paths[b]; 303 best_paths_b.resize(decode_helper_.GetTopPaths()); 304 for (int t = 0; t < seq_len_t(b); ++t) { 305 input_chip_t = input_list_t[t].chip(b, 0); 306 auto input_bi = 307 Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes); 308 beam_search.Step(input_bi); 309 } 310 OP_REQUIRES_OK( 311 ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b, 312 &log_probs, merge_repeated_)); 313 314 beam_search.Reset(); 315 316 for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) { 317 log_prob_t(b, bp) = log_probs[bp]; 318 } 319 } 320 321 OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences( 322 best_paths, &decoded_indices, &decoded_values, 323 &decoded_shape)); 324 } 325 326 private: 327 CTCDecodeHelper decode_helper_; 328 ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_; 329 bool merge_repeated_; 330 int beam_width_; 331 TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp); 332 }; 333 334 REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoder").Device(DEVICE_CPU), 335 CTCBeamSearchDecoderOp); 336 337 } // end namespace tensorflow 338