/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/ctc_ops.cc.

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

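// Computes the CTC (Connectionist Temporal Classification) loss and the
// gradient with respect to the input logits, on CPU.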
class CTCLossOp : public OpKernel {
  typedef Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic,
                                         Eigen::RowMajor> >
      InputMap;
  typedef Eigen::Map<
      Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
      OutputMap;

 public:
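  // Attributes (documented in ../ops/ctc_ops.cc):
  //   preprocess_collapse_repeated: collapse repeated labels before the loss
  //     calculation.
  //   ctc_merge_repeated: if true, merge repeated (non-blank) labels during
  //     the forward pass.
  //   ignore_longer_outputs_than_inputs: skip batch entries whose label
  //     sequence cannot fit in the input sequence (they get zero loss and
  //     zero gradient) instead of returning an error.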
  explicit CTCLossOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated",
                                     &preprocess_collapse_repeated_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs",
                                     &ignore_longer_outputs_than_inputs_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* labels_indices;
    const Tensor* labels_values;
    const Tensor* seq_len;
    OP_REQUIRES_OK(ctx, ctx->input("inputs", &inputs));
    OP_REQUIRES_OK(ctx, ctx->input("labels_indices", &labels_indices));
    OP_REQUIRES_OK(ctx, ctx->input("labels_values", &labels_values));
    OP_REQUIRES_OK(ctx, ctx->input("sequence_length", &seq_len));

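    // Validate ranks: inputs must be a 3-D tensor, sequence_length a vector,
    // and (labels_indices, labels_values) the indices/values pair of a
    // SparseTensor (a matrix and a vector, respectively).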
    OP_REQUIRES(ctx, inputs->shape().dims() == 3,
                errors::InvalidArgument("inputs is not a 3-Tensor"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(seq_len->shape()),
                errors::InvalidArgument("sequence_length is not a vector"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()),
                errors::InvalidArgument("labels_indices is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()),
                errors::InvalidArgument("labels_values is not a vector"));

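    // inputs is time-major: [max_time, batch_size, num_classes].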
    const TensorShape& inputs_shape = inputs->shape();
    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);
    const int64 num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int num_classes = static_cast<int>(num_classes_raw);

    OP_REQUIRES(
        ctx, batch_size == seq_len->dim_size(0),
        errors::InvalidArgument("len(sequence_length) != batch_size.  ",
                                "len(sequence_length):  ", seq_len->dim_size(0),
                                " batch_size: ", batch_size));
    auto seq_len_t = seq_len->vec<int32>();

    OP_REQUIRES(ctx, labels_indices->dim_size(0) == labels_values->dim_size(0),
                errors::InvalidArgument(
                    "labels_indices and labels_values must contain the "
                    "same number of rows, but saw shapes: ",
                    labels_indices->shape().DebugString(), " vs. ",
                    labels_values->shape().DebugString()));

    OP_REQUIRES(ctx, batch_size != 0,
                errors::InvalidArgument("batch_size must not be 0"));

    // Figure out the maximum label length to use as the sparse tensor
    // dimension.
    auto labels_indices_t = labels_indices->matrix<int64>();
    int64 max_label_len = 0;
    for (int64 i = 0; i < labels_indices->dim_size(0); i++) {
      max_label_len = std::max(max_label_len, labels_indices_t(i, 1) + 1);
    }

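    // View the labels as a SparseTensor of dense shape
    // [batch_size, max_label_len], with indices in row-major order.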
    TensorShape labels_shape({batch_size, max_label_len});
    std::vector<int64> order{0, 1};
    sparse::SparseTensor labels_sp(*labels_indices, *labels_values,
                                   labels_shape, order);

    Status labels_sp_valid = labels_sp.IndicesValid();
    OP_REQUIRES(ctx, labels_sp_valid.ok(),
                errors::InvalidArgument("label SparseTensor is not valid: ",
                                        labels_sp_valid.error_message()));

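    // Group the label indices/values by batch entry (dimension 0) and copy
    // each group into a dense vector of label ids for that batch element.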
    ctc::CTCLossCalculator::LabelSequences labels_t(batch_size);
    for (const auto& g : labels_sp.group({0})) {  // iterate by batch
      const int64 batch_indices = g.group()[0];
      OP_REQUIRES(ctx, FastBoundsCheck(batch_indices, batch_size),
                  errors::InvalidArgument("labels batch index must be in [0, ",
                                          batch_size, ") but saw: ",
                                          batch_indices));

      auto values = g.values<int32>();
      std::vector<int>* b_values = &labels_t[batch_indices];
      b_values->resize(values.size());
      for (int64 i = 0; i < values.size(); ++i) (*b_values)[i] = values(i);
    }

    OP_REQUIRES(ctx, static_cast<size_t>(batch_size) == labels_t.size(),
                errors::InvalidArgument("len(labels) != batch_size.  ",
                                        "len(labels):  ", labels_t.size(),
                                        " batch_size: ", batch_size));

    for (int64 b = 0; b < batch_size; ++b) {
      OP_REQUIRES(ctx, seq_len_t(b) <= max_time,
                  errors::InvalidArgument("sequence_length(", b,
                                          ") must be <= max_time (", max_time,
                                          ") but is ", seq_len_t(b)));
    }

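    // Allocate the outputs: loss is a [batch_size] vector (one value per
    // batch element); gradient has the same [max_time, batch_size,
    // num_classes] shape as inputs.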
    Tensor* loss = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));
    auto loss_t = loss->vec<float>();

    Tensor* gradient = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("gradient", inputs_shape, &gradient));
    auto gradient_t = gradient->tensor<float, 3>();
    auto inputs_t = inputs->tensor<float, 3>();
    std::vector<OutputMap> gradient_list_t;
    std::vector<InputMap> input_list_t;

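    // Build per-time-step 2-D views of shape [batch_size, num_classes] over
    // the 3-D input and gradient buffers; Eigen::Map aliases the underlying
    // storage, so nothing is copied.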
    for (int64 t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
      gradient_list_t.emplace_back(
          gradient_t.data() + t * batch_size * num_classes, batch_size,
          num_classes);
    }

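    // Zero the gradient up front so that entries the loss calculator never
    // writes (e.g. time steps at or beyond a sequence's length) remain zero.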
    gradient_t.setZero();

    // Assumption: the blank index is num_classes - 1. The second constructor
    // argument is the output delay, which is 0 here.
    ctc::CTCLossCalculator ctc_loss_calculator(num_classes - 1, 0);
    DeviceBase::CpuWorkerThreads workers =
        *ctx->device()->tensorflow_cpu_worker_threads();
    OP_REQUIRES_OK(ctx, ctc_loss_calculator.CalculateLoss(
                            seq_len_t, labels_t, input_list_t,
                            preprocess_collapse_repeated_, ctc_merge_repeated_,
                            ignore_longer_outputs_than_inputs_, &loss_t,
                            &gradient_list_t, &workers));
  }

 private:
  bool preprocess_collapse_repeated_;
  bool ctc_merge_repeated_;
  bool ignore_longer_outputs_than_inputs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp);
};

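// Register the CPU implementation of the CTCLoss op.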
REGISTER_KERNEL_BUILDER(Name("CTCLoss").Device(DEVICE_CPU), CTCLossOp);

}  // end namespace tensorflow