1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/array_ops.cc. 17 18 #define EIGEN_USE_THREADS 19 20 #include <limits> 21 22 #include <vector> 23 #include "tensorflow/core/common_runtime/device.h" 24 #include "tensorflow/core/framework/op.h" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/register_types.h" 27 #include "tensorflow/core/framework/types.h" 28 #include "tensorflow/core/lib/core/status.h" 29 #include "tensorflow/core/lib/gtl/edit_distance.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/util/sparse/sparse_tensor.h" 33 34 namespace tensorflow { 35 36 namespace { 37 38 Status ValidateShapes(OpKernelContext* ctx, const Tensor& hypothesis_indices, 39 const Tensor& hypothesis_values, 40 const Tensor& hypothesis_shape, 41 const Tensor& truth_indices, const Tensor& truth_values, 42 const Tensor& truth_shape) { 43 if (!TensorShapeUtils::IsMatrix(hypothesis_indices.shape())) 44 return errors::InvalidArgument( 45 "hypothesis_indices should be a matrix, but got shape: ", 46 hypothesis_indices.shape().DebugString()); 47 if (!TensorShapeUtils::IsMatrix(truth_indices.shape())) 48 return errors::InvalidArgument( 49 "truth_indices should be a matrix, but got shape: ", 50 truth_indices.shape().DebugString()); 51 if (!TensorShapeUtils::IsVector(hypothesis_values.shape())) 52 return errors::InvalidArgument( 53 "hypothesis_values should be a vector, but got shape: ", 54 hypothesis_values.shape().DebugString()); 55 if (!TensorShapeUtils::IsVector(truth_values.shape())) 56 return errors::InvalidArgument( 57 "truth_values should be a vector, but got shape: ", 58 truth_values.shape().DebugString()); 59 if (!TensorShapeUtils::IsVector(hypothesis_shape.shape())) 60 return errors::InvalidArgument( 61 "hypothesis_shape should be a vector, but got shape: ", 62 hypothesis_shape.shape().DebugString()); 63 if (!TensorShapeUtils::IsVector(truth_shape.shape())) 64 return errors::InvalidArgument( 65 "truth_shape should be a vector, but got shape: ", 66 truth_shape.shape().DebugString()); 67 if (hypothesis_shape.NumElements() != hypothesis_indices.dim_size(1)) 68 return errors::InvalidArgument( 69 "Expected hypothesis_shape.NumElements == " 70 "#cols(hypothesis_indices), their shapes are: ", 71 hypothesis_shape.shape().DebugString(), " and ", 72 hypothesis_indices.shape().DebugString()); 73 if (truth_shape.NumElements() < 2) 74 return errors::InvalidArgument( 75 "Input SparseTensors must have rank at least 2, but truth_shape " 76 "rank is: ", 77 truth_shape.NumElements()); 78 if (truth_shape.NumElements() != truth_indices.dim_size(1)) 79 return errors::InvalidArgument( 80 "Expected truth_shape.NumElements == " 81 "#cols(truth_indices), their shapes are: ", 82 truth_shape.shape().DebugString(), " and ", 83 truth_indices.shape().DebugString()); 84 if (truth_shape.NumElements() != hypothesis_shape.NumElements()) 85 return errors::InvalidArgument( 86 "Expected truth and hypothesis to have matching ranks, but " 87 "their shapes are: ", 88 truth_shape.shape().DebugString(), " and ", 89 hypothesis_shape.shape().DebugString()); 90 91 return Status::OK(); 92 } 93 94 } // namespace 95 96 template <typename T> 97 class EditDistanceOp : public OpKernel { 98 public: 99 explicit EditDistanceOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 100 OP_REQUIRES_OK(ctx, ctx->GetAttr("normalize", &normalize_)); 101 } 102 103 void Compute(OpKernelContext* ctx) override { 104 const Tensor* hypothesis_indices; 105 const Tensor* hypothesis_values; 106 const Tensor* hypothesis_shape; 107 const Tensor* truth_indices; 108 const Tensor* truth_values; 109 const Tensor* truth_shape; 110 OP_REQUIRES_OK(ctx, ctx->input("hypothesis_indices", &hypothesis_indices)); 111 OP_REQUIRES_OK(ctx, ctx->input("hypothesis_values", &hypothesis_values)); 112 OP_REQUIRES_OK(ctx, ctx->input("hypothesis_shape", &hypothesis_shape)); 113 OP_REQUIRES_OK(ctx, ctx->input("truth_indices", &truth_indices)); 114 OP_REQUIRES_OK(ctx, ctx->input("truth_values", &truth_values)); 115 OP_REQUIRES_OK(ctx, ctx->input("truth_shape", &truth_shape)); 116 117 OP_REQUIRES_OK( 118 ctx, ValidateShapes(ctx, *hypothesis_indices, *hypothesis_values, 119 *hypothesis_shape, *truth_indices, *truth_values, 120 *truth_shape)); 121 122 TensorShape hypothesis_st_shape; 123 OP_REQUIRES_OK( 124 ctx, TensorShapeUtils::MakeShape(hypothesis_shape->vec<int64>().data(), 125 hypothesis_shape->NumElements(), 126 &hypothesis_st_shape)); 127 TensorShape truth_st_shape; 128 OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape( 129 truth_shape->vec<int64>().data(), 130 truth_shape->NumElements(), &truth_st_shape)); 131 132 // Assume indices are sorted in row-major order. 133 std::vector<int64> sorted_order(truth_st_shape.dims()); 134 std::iota(sorted_order.begin(), sorted_order.end(), 0); 135 136 sparse::SparseTensor hypothesis(*hypothesis_indices, *hypothesis_values, 137 hypothesis_st_shape, sorted_order); 138 sparse::SparseTensor truth(*truth_indices, *truth_values, truth_st_shape, 139 sorted_order); 140 141 // Group dims 0, 1, ..., RANK - 1. The very last dim is assumed 142 // to store the variable length sequences. 143 std::vector<int64> group_dims(truth_st_shape.dims() - 1); 144 std::iota(group_dims.begin(), group_dims.end(), 0); 145 146 TensorShape output_shape; 147 for (int d = 0; d < static_cast<int>(group_dims.size()); ++d) { 148 output_shape.AddDim(std::max(hypothesis_st_shape.dim_size(d), 149 truth_st_shape.dim_size(d))); 150 } 151 152 Tensor* output = nullptr; 153 OP_REQUIRES_OK(ctx, ctx->allocate_output("output", output_shape, &output)); 154 auto output_t = output->flat<float>(); 155 output_t.setZero(); 156 157 std::vector<int64> output_strides(output_shape.dims()); 158 output_strides[output_shape.dims() - 1] = 1; 159 for (int d = output_shape.dims() - 2; d >= 0; --d) { 160 output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1); 161 } 162 163 auto hypothesis_grouper = hypothesis.group(group_dims); 164 auto truth_grouper = truth.group(group_dims); 165 166 auto hypothesis_iter = hypothesis_grouper.begin(); 167 auto truth_iter = truth_grouper.begin(); 168 169 auto cmp = std::equal_to<T>(); 170 171 while (hypothesis_iter != hypothesis_grouper.end() && 172 truth_iter != truth_grouper.end()) { 173 sparse::Group truth_i = *truth_iter; 174 sparse::Group hypothesis_j = *hypothesis_iter; 175 std::vector<int64> g_truth = truth_i.group(); 176 std::vector<int64> g_hypothesis = hypothesis_j.group(); 177 auto truth_seq = truth_i.values<T>(); 178 auto hypothesis_seq = hypothesis_j.values<T>(); 179 180 if (g_truth == g_hypothesis) { 181 auto loc = std::inner_product(g_truth.begin(), g_truth.end(), 182 output_strides.begin(), int64{0}); 183 output_t(loc) = 184 gtl::LevenshteinDistance<T>(truth_seq, hypothesis_seq, cmp); 185 if (normalize_) output_t(loc) /= truth_seq.size(); 186 187 ++hypothesis_iter; 188 ++truth_iter; 189 } else if (g_truth > g_hypothesis) { // zero-length truth 190 auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(), 191 output_strides.begin(), int64{0}); 192 output_t(loc) = hypothesis_seq.size(); 193 if (normalize_ && output_t(loc) != 0.0f) { 194 output_t(loc) = std::numeric_limits<float>::infinity(); 195 } 196 ++hypothesis_iter; 197 } else { // zero-length hypothesis 198 auto loc = std::inner_product(g_truth.begin(), g_truth.end(), 199 output_strides.begin(), int64{0}); 200 output_t(loc) = (normalize_) ? 1.0 : truth_seq.size(); 201 ++truth_iter; 202 } 203 } 204 while (hypothesis_iter != hypothesis_grouper.end()) { // zero-length truths 205 sparse::Group hypothesis_j = *hypothesis_iter; 206 std::vector<int64> g_hypothesis = hypothesis_j.group(); 207 auto hypothesis_seq = hypothesis_j.values<T>(); 208 auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(), 209 output_strides.begin(), int64{0}); 210 output_t(loc) = hypothesis_seq.size(); 211 if (normalize_ && output_t(loc) != 0.0f) { 212 output_t(loc) = std::numeric_limits<float>::infinity(); 213 } 214 ++hypothesis_iter; 215 } 216 while (truth_iter != truth_grouper.end()) { // missing hypotheses 217 sparse::Group truth_i = *truth_iter; 218 std::vector<int64> g_truth = truth_i.group(); 219 auto truth_seq = truth_i.values<T>(); 220 auto loc = std::inner_product(g_truth.begin(), g_truth.end(), 221 output_strides.begin(), int64{0}); 222 output_t(loc) = (normalize_) ? 1.0 : truth_seq.size(); 223 ++truth_iter; 224 } 225 } 226 227 private: 228 bool normalize_; 229 230 TF_DISALLOW_COPY_AND_ASSIGN(EditDistanceOp); 231 }; 232 233 #define REGISTER_CPU_KERNEL(T) \ 234 REGISTER_KERNEL_BUILDER( \ 235 Name("EditDistance").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 236 EditDistanceOp<T>); 237 238 TF_CALL_POD_STRING_TYPES(REGISTER_CPU_KERNEL); 239 240 #undef REGISTER_CPU_KERNEL 241 242 } // end namespace tensorflow 243