/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#include <limits>
#include <numeric>  // for std::iota, std::inner_product
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/edit_distance.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

namespace {

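// Validates the six SparseTensor components (indices, values, and shape for
// both hypothesis and truth): indices must be matrices, values and shapes
// must be vectors, the two inputs must have matching ranks of at least 2, and
// each shape vector must have one entry per index column.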
Status ValidateShapes(OpKernelContext* ctx, const Tensor& hypothesis_indices,
                      const Tensor& hypothesis_values,
                      const Tensor& hypothesis_shape,
                      const Tensor& truth_indices, const Tensor& truth_values,
                      const Tensor& truth_shape) {
  if (!TensorShapeUtils::IsMatrix(hypothesis_indices.shape()))
    return errors::InvalidArgument(
        "hypothesis_indices should be a matrix, but got shape: ",
        hypothesis_indices.shape().DebugString());
  if (!TensorShapeUtils::IsMatrix(truth_indices.shape()))
    return errors::InvalidArgument(
        "truth_indices should be a matrix, but got shape: ",
        truth_indices.shape().DebugString());
  if (!TensorShapeUtils::IsVector(hypothesis_values.shape()))
    return errors::InvalidArgument(
        "hypothesis_values should be a vector, but got shape: ",
        hypothesis_values.shape().DebugString());
  if (!TensorShapeUtils::IsVector(truth_values.shape()))
    return errors::InvalidArgument(
        "truth_values should be a vector, but got shape: ",
        truth_values.shape().DebugString());
  if (!TensorShapeUtils::IsVector(hypothesis_shape.shape()))
    return errors::InvalidArgument(
        "hypothesis_shape should be a vector, but got shape: ",
        hypothesis_shape.shape().DebugString());
  if (!TensorShapeUtils::IsVector(truth_shape.shape()))
    return errors::InvalidArgument(
        "truth_shape should be a vector, but got shape: ",
        truth_shape.shape().DebugString());
  if (hypothesis_shape.NumElements() != hypothesis_indices.dim_size(1))
    return errors::InvalidArgument(
        "Expected hypothesis_shape.NumElements == "
        "#cols(hypothesis_indices), their shapes are: ",
        hypothesis_shape.shape().DebugString(), " and ",
        hypothesis_indices.shape().DebugString());
  if (truth_shape.NumElements() < 2)
    return errors::InvalidArgument(
        "Input SparseTensors must have rank at least 2, but truth_shape "
        "rank is: ",
        truth_shape.NumElements());
  if (truth_shape.NumElements() != truth_indices.dim_size(1))
    return errors::InvalidArgument(
        "Expected truth_shape.NumElements == "
        "#cols(truth_indices), their shapes are: ",
        truth_shape.shape().DebugString(), " and ",
        truth_indices.shape().DebugString());
  if (truth_shape.NumElements() != hypothesis_shape.NumElements())
    return errors::InvalidArgument(
        "Expected truth and hypothesis to have matching ranks, but "
        "their shapes are: ",
        truth_shape.shape().DebugString(), " and ",
        hypothesis_shape.shape().DebugString());

  return Status::OK();
}

}  // namespace

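// Computes the Levenshtein edit distance between each pair of corresponding
// (hypothesis, truth) sequences stored in two SparseTensors whose last
// dimension holds the variable-length sequences.  If the "normalize" attr is
// set, each distance is divided by the length of the truth sequence.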
template <typename T>
class EditDistanceOp : public OpKernel {
 public:
  explicit EditDistanceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("normalize", &normalize_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* hypothesis_indices;
    const Tensor* hypothesis_values;
    const Tensor* hypothesis_shape;
    const Tensor* truth_indices;
    const Tensor* truth_values;
    const Tensor* truth_shape;
    OP_REQUIRES_OK(ctx, ctx->input("hypothesis_indices", &hypothesis_indices));
    OP_REQUIRES_OK(ctx, ctx->input("hypothesis_values", &hypothesis_values));
    OP_REQUIRES_OK(ctx, ctx->input("hypothesis_shape", &hypothesis_shape));
    OP_REQUIRES_OK(ctx, ctx->input("truth_indices", &truth_indices));
    OP_REQUIRES_OK(ctx, ctx->input("truth_values", &truth_values));
    OP_REQUIRES_OK(ctx, ctx->input("truth_shape", &truth_shape));

    OP_REQUIRES_OK(
        ctx, ValidateShapes(ctx, *hypothesis_indices, *hypothesis_values,
                            *hypothesis_shape, *truth_indices, *truth_values,
                            *truth_shape));

    TensorShape hypothesis_st_shape;
    OP_REQUIRES_OK(
        ctx, TensorShapeUtils::MakeShape(hypothesis_shape->vec<int64>().data(),
                                         hypothesis_shape->NumElements(),
                                         &hypothesis_st_shape));
    TensorShape truth_st_shape;
    OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
                            truth_shape->vec<int64>().data(),
                            truth_shape->NumElements(), &truth_st_shape));

    // Assume indices are sorted in row-major order.
    std::vector<int64> sorted_order(truth_st_shape.dims());
    std::iota(sorted_order.begin(), sorted_order.end(), 0);

    sparse::SparseTensor hypothesis(*hypothesis_indices, *hypothesis_values,
                                    hypothesis_st_shape, sorted_order);
    sparse::SparseTensor truth(*truth_indices, *truth_values, truth_st_shape,
                               sorted_order);

    // Group by dims 0, 1, ..., RANK - 2 (all dims except the last).  The very
    // last dim is assumed to store the variable length sequences.
    std::vector<int64> group_dims(truth_st_shape.dims() - 1);
    std::iota(group_dims.begin(), group_dims.end(), 0);

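    // The output has one entry per (hypothesis, truth) group.  Its shape is
    // the elementwise max of the two inputs' group dims; e.g. a hypothesis
    // dense shape of [2, 3, 10] and a truth dense shape of [4, 3, 12] yield
    // an output shape of [4, 3] (the trailing sequence dim is dropped).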
    TensorShape output_shape;
    for (int d = 0; d < static_cast<int>(group_dims.size()); ++d) {
      output_shape.AddDim(std::max(hypothesis_st_shape.dim_size(d),
                                   truth_st_shape.dim_size(d)));
    }

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("output", output_shape, &output));
    auto output_t = output->flat<float>();
    output_t.setZero();

    std::vector<int64> output_strides(output_shape.dims());
    output_strides[output_shape.dims() - 1] = 1;
    for (int d = output_shape.dims() - 2; d >= 0; --d) {
      output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
    }

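    // output_strides holds row-major strides for the output shape, so a group
    // index maps to a flat offset via an inner product; e.g. for an output
    // shape of [4, 3] the strides are [3, 1] and group (i, j) lands at
    // flat index i * 3 + j.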
    auto hypothesis_grouper = hypothesis.group(group_dims);
    auto truth_grouper = truth.group(group_dims);

    auto hypothesis_iter = hypothesis_grouper.begin();
    auto truth_iter = truth_grouper.begin();

    auto cmp = std::equal_to<T>();

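    // Merge-join the two groupers in lexicographic group order.  Groups that
    // appear in both inputs get the Levenshtein distance between their
    // sequences (e.g. the distance between "kitten" and "sitting" is 3);
    // groups present in only one input are scored against an implicit empty
    // sequence on the other side.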
    while (hypothesis_iter != hypothesis_grouper.end() &&
           truth_iter != truth_grouper.end()) {
      sparse::Group truth_i = *truth_iter;
      sparse::Group hypothesis_j = *hypothesis_iter;
      std::vector<int64> g_truth = truth_i.group();
      std::vector<int64> g_hypothesis = hypothesis_j.group();
      auto truth_seq = truth_i.values<T>();
      auto hypothesis_seq = hypothesis_j.values<T>();

      if (g_truth == g_hypothesis) {
        auto loc = std::inner_product(g_truth.begin(), g_truth.end(),
                                      output_strides.begin(), int64{0});
        output_t(loc) =
            gtl::LevenshteinDistance<T>(truth_seq, hypothesis_seq, cmp);
        if (normalize_) output_t(loc) /= truth_seq.size();

        ++hypothesis_iter;
        ++truth_iter;
      } else if (g_truth > g_hypothesis) {  // zero-length truth
        auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(),
                                      output_strides.begin(), int64{0});
        output_t(loc) = hypothesis_seq.size();
        if (normalize_ && output_t(loc) != 0.0f) {
          output_t(loc) = std::numeric_limits<float>::infinity();
        }
        ++hypothesis_iter;
      } else {  // zero-length hypothesis
        auto loc = std::inner_product(g_truth.begin(), g_truth.end(),
                                      output_strides.begin(), int64{0});
        output_t(loc) = (normalize_) ? 1.0 : truth_seq.size();
        ++truth_iter;
      }
    }
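    // Drain whichever grouper still has entries; each remaining group is
    // scored against an empty sequence from the exhausted input.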
    while (hypothesis_iter != hypothesis_grouper.end()) {  // zero-length truths
      sparse::Group hypothesis_j = *hypothesis_iter;
      std::vector<int64> g_hypothesis = hypothesis_j.group();
      auto hypothesis_seq = hypothesis_j.values<T>();
      auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(),
                                    output_strides.begin(), int64{0});
      output_t(loc) = hypothesis_seq.size();
      if (normalize_ && output_t(loc) != 0.0f) {
        output_t(loc) = std::numeric_limits<float>::infinity();
      }
      ++hypothesis_iter;
    }
    while (truth_iter != truth_grouper.end()) {  // missing hypotheses
      sparse::Group truth_i = *truth_iter;
      std::vector<int64> g_truth = truth_i.group();
      auto truth_seq = truth_i.values<T>();
      auto loc = std::inner_product(g_truth.begin(), g_truth.end(),
                                    output_strides.begin(), int64{0});
      output_t(loc) = (normalize_) ? 1.0 : truth_seq.size();
      ++truth_iter;
    }
  }

 private:
  bool normalize_;

  TF_DISALLOW_COPY_AND_ASSIGN(EditDistanceOp);
};

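// Register the CPU kernel for every POD type and for strings, matching the
// "T" type attribute of the EditDistance op.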
#define REGISTER_CPU_KERNEL(T)                                        \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("EditDistance").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      EditDistanceOp<T>);

TF_CALL_POD_STRING_TYPES(REGISTER_CPU_KERNEL);

#undef REGISTER_CPU_KERNEL

}  // end namespace tensorflow