1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include <algorithm> 19 #include <numeric> 20 #include <unordered_map> 21 #include <utility> 22 #include <vector> 23 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/register_types.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor.pb.h" 28 #include "tensorflow/core/framework/tensor_util.h" 29 #include "tensorflow/core/framework/types.h" 30 #include "tensorflow/core/framework/variant.h" 31 #include "tensorflow/core/framework/variant_encode_decode.h" 32 #include "tensorflow/core/kernels/reshape_util.h" 33 #include "tensorflow/core/lib/gtl/inlined_vector.h" 34 #include "tensorflow/core/lib/gtl/optional.h" 35 #include "tensorflow/core/util/sparse/sparse_tensor.h" 36 37 namespace tensorflow { 38 39 using sparse::SparseTensor; 40 41 template <typename T> 42 class SerializeSparseOp : public OpKernel { 43 public: 44 explicit SerializeSparseOp(OpKernelConstruction* context) 45 : OpKernel(context) {} 46 47 Status Initialize(Tensor* result); 48 Status Serialize(const Tensor& input, T* result); 49 50 void Compute(OpKernelContext* context) override { 51 const Tensor* input_indices; 52 const Tensor* input_values; 53 const Tensor* input_shape; 54 55 OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices)); 56 OP_REQUIRES_OK(context, context->input("sparse_values", &input_values)); 57 OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape)); 58 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()), 59 errors::InvalidArgument( 60 "Input indices should be a matrix but received shape ", 61 input_indices->shape().DebugString())); 62 63 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()), 64 errors::InvalidArgument( 65 "Input values should be a vector but received shape ", 66 input_values->shape().DebugString())); 67 68 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()), 69 errors::InvalidArgument( 70 "Input shape should be a vector but received shape ", 71 input_shape->shape().DebugString())); 72 73 Tensor serialized_sparse; 74 OP_REQUIRES_OK(context, Initialize(&serialized_sparse)); 75 76 auto serialized_sparse_t = serialized_sparse.vec<T>(); 77 OP_REQUIRES_OK(context, Serialize(*input_indices, &serialized_sparse_t(0))); 78 OP_REQUIRES_OK(context, Serialize(*input_values, &serialized_sparse_t(1))); 79 OP_REQUIRES_OK(context, Serialize(*input_shape, &serialized_sparse_t(2))); 80 81 context->set_output(0, serialized_sparse); 82 } 83 }; 84 85 template <> 86 Status SerializeSparseOp<string>::Initialize(Tensor* result) { 87 *result = Tensor(DT_STRING, TensorShape({3})); 88 return Status::OK(); 89 } 90 91 template <> 92 Status SerializeSparseOp<string>::Serialize(const Tensor& input, 93 string* result) { 94 TensorProto proto; 95 input.AsProtoTensorContent(&proto); 96 *result = proto.SerializeAsString(); 97 return Status::OK(); 98 } 99 100 REGISTER_KERNEL_BUILDER(Name("SerializeSparse") 101 .Device(DEVICE_CPU) 102 .TypeConstraint<string>("out_type"), 103 SerializeSparseOp<string>); 104 105 template <> 106 Status SerializeSparseOp<Variant>::Initialize(Tensor* result) { 107 *result = Tensor(DT_VARIANT, TensorShape({3})); 108 return Status::OK(); 109 } 110 111 template <> 112 Status SerializeSparseOp<Variant>::Serialize(const Tensor& input, 113 Variant* result) { 114 *result = input; 115 return Status::OK(); 116 } 117 118 REGISTER_KERNEL_BUILDER(Name("SerializeSparse") 119 .Device(DEVICE_CPU) 120 .TypeConstraint<Variant>("out_type"), 121 SerializeSparseOp<Variant>); 122 123 template <typename T> 124 class SerializeManySparseOpBase : public OpKernel { 125 public: 126 explicit SerializeManySparseOpBase(OpKernelConstruction* context) 127 : OpKernel(context) {} 128 129 void Compute(OpKernelContext* context) override {} 130 131 protected: 132 Status Initialize(const int64 n, Tensor* result); 133 Status Serialize(const Tensor& input, T* result); 134 }; 135 136 template <typename T, typename U> 137 class SerializeManySparseOp : public SerializeManySparseOpBase<U> { 138 public: 139 explicit SerializeManySparseOp(OpKernelConstruction* context) 140 : SerializeManySparseOpBase<U>(context) {} 141 142 void Compute(OpKernelContext* context) override { 143 const Tensor* input_indices; 144 const Tensor* input_values; 145 const Tensor* input_shape; 146 OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices)); 147 OP_REQUIRES_OK(context, context->input("sparse_values", &input_values)); 148 OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape)); 149 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()), 150 errors::InvalidArgument( 151 "Input indices should be a matrix but received shape ", 152 input_indices->shape().DebugString())); 153 154 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()), 155 errors::InvalidArgument( 156 "Input values should be a vector but received shape ", 157 input_values->shape().DebugString())); 158 159 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()), 160 errors::InvalidArgument( 161 "Input shape should be a vector but received shape ", 162 input_shape->shape().DebugString())); 163 164 int rank = input_shape->NumElements(); 165 166 OP_REQUIRES( 167 context, rank > 1, 168 errors::InvalidArgument( 169 "Rank of input SparseTensor should be > 1, but saw rank: ", rank)); 170 171 TensorShape tensor_input_shape(input_shape->vec<int64>()); 172 gtl::InlinedVector<int64, 8> std_order(rank); 173 std::iota(std_order.begin(), std_order.end(), 0); 174 SparseTensor input_st(*input_indices, *input_values, tensor_input_shape, 175 std_order); 176 177 auto input_shape_t = input_shape->vec<int64>(); 178 const int64 N = input_shape_t(0); 179 Tensor serialized_sparse; 180 OP_REQUIRES_OK(context, this->Initialize(N, &serialized_sparse)); 181 auto serialized_sparse_t = serialized_sparse.matrix<U>(); 182 183 OP_REQUIRES_OK(context, input_st.IndicesValid()); 184 185 // Initialize output with empty values and the proper shapes. 186 Tensor output_blank_indices(DT_INT64, {0, rank - 1}); 187 U serialized_indices; 188 OP_REQUIRES_OK(context, 189 this->Serialize(output_blank_indices, &serialized_indices)); 190 serialized_sparse_t.template chip<1>(0).setConstant(serialized_indices); 191 192 Tensor output_blank_values(DataTypeToEnum<T>::value, {0}); 193 U serialized_values; 194 OP_REQUIRES_OK(context, 195 this->Serialize(output_blank_values, &serialized_values)); 196 serialized_sparse_t.template chip<1>(1).setConstant(serialized_values); 197 198 Tensor output_shape(DT_INT64, {rank - 1}); 199 auto output_shape_t = output_shape.vec<int64>(); 200 for (int d = 1; d < rank; d++) output_shape_t(d - 1) = input_shape_t(d); 201 U serialized_shape; 202 OP_REQUIRES_OK(context, this->Serialize(output_shape, &serialized_shape)); 203 serialized_sparse_t.template chip<1>(2).setConstant(serialized_shape); 204 205 // Get groups by minibatch dimension 206 sparse::GroupIterable minibatch = input_st.group({0}); 207 for (const auto& subset : minibatch) { 208 const int64 b = subset.group()[0]; 209 OP_REQUIRES( 210 context, b > -1 && b < N, 211 errors::InvalidArgument( 212 "Received unexpected column 0 value in input SparseTensor: ", b, 213 " < 0 or >= N (= ", N, ")")); 214 215 const auto indices = subset.indices(); 216 const auto values = subset.values<T>(); 217 const int64 num_entries = values.size(); 218 219 Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1}); 220 Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries}); 221 222 auto output_indices_t = output_indices.matrix<int64>(); 223 auto output_values_t = output_values.vec<T>(); 224 225 for (int i = 0; i < num_entries; ++i) { 226 for (int d = 1; d < rank; ++d) { 227 output_indices_t(i, d - 1) = indices(i, d); 228 } 229 output_values_t(i) = values(i); 230 } 231 232 OP_REQUIRES_OK( 233 context, this->Serialize(output_indices, &serialized_sparse_t(b, 0))); 234 OP_REQUIRES_OK( 235 context, this->Serialize(output_values, &serialized_sparse_t(b, 1))); 236 } 237 238 context->set_output(0, serialized_sparse); 239 } 240 }; 241 242 template <> 243 Status SerializeManySparseOpBase<string>::Initialize(const int64 n, 244 Tensor* result) { 245 *result = Tensor(DT_STRING, TensorShape({n, 3})); 246 return Status::OK(); 247 } 248 249 template <> 250 Status SerializeManySparseOpBase<string>::Serialize(const Tensor& input, 251 string* result) { 252 TensorProto proto; 253 input.AsProtoTensorContent(&proto); 254 *result = proto.SerializeAsString(); 255 return Status::OK(); 256 } 257 258 #define REGISTER_KERNELS(type) \ 259 REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \ 260 .Device(DEVICE_CPU) \ 261 .TypeConstraint<type>("T") \ 262 .TypeConstraint<string>("out_type"), \ 263 SerializeManySparseOp<type, string>) 264 265 TF_CALL_ALL_TYPES(REGISTER_KERNELS); 266 #undef REGISTER_KERNELS 267 268 template <> 269 Status SerializeManySparseOpBase<Variant>::Initialize(const int64 n, 270 Tensor* result) { 271 *result = Tensor(DT_VARIANT, TensorShape({n, 3})); 272 return Status::OK(); 273 } 274 275 template <> 276 Status SerializeManySparseOpBase<Variant>::Serialize(const Tensor& input, 277 Variant* result) { 278 *result = input; 279 return Status::OK(); 280 } 281 282 #define REGISTER_KERNELS(type) \ 283 REGISTER_KERNEL_BUILDER(Name("SerializeManySparse") \ 284 .Device(DEVICE_CPU) \ 285 .TypeConstraint<type>("T") \ 286 .TypeConstraint<Variant>("out_type"), \ 287 SerializeManySparseOp<type, Variant>) 288 289 TF_CALL_ALL_TYPES(REGISTER_KERNELS); 290 #undef REGISTER_KERNELS 291 292 template <typename T> 293 class DeserializeSparseOp : public OpKernel { 294 public: 295 explicit DeserializeSparseOp(OpKernelConstruction* context) 296 : OpKernel(context) { 297 OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); 298 } 299 300 void Compute(OpKernelContext* context) override { 301 const Tensor& serialized_sparse = context->input(0); 302 const int ndims = serialized_sparse.shape().dims(); 303 304 OP_REQUIRES( 305 context, ndims > 0, 306 errors::InvalidArgument("Serialized sparse should have non-zero rank ", 307 serialized_sparse.shape().DebugString())); 308 309 OP_REQUIRES(context, serialized_sparse.shape().dim_size(ndims - 1) == 3, 310 errors::InvalidArgument( 311 "Serialized sparse should have 3 as the last dimension ", 312 serialized_sparse.shape().DebugString())); 313 314 int num_sparse_tensors = 1; 315 for (int i = 0; i < ndims - 1; ++i) { 316 num_sparse_tensors *= serialized_sparse.shape().dim_size(i); 317 } 318 319 OP_REQUIRES( 320 context, num_sparse_tensors > 0, 321 errors::InvalidArgument( 322 "Serialized sparse should have at least 1 serialized tensor, " 323 "but has a zero dimension ", 324 serialized_sparse.shape().DebugString())); 325 326 if (num_sparse_tensors == 0 && serialized_sparse.shape().dims() == 1) { 327 // Special case with a single sparse tensor. We can avoid data 328 // motion in the Concat and Reshape. 329 const auto& serialized_sparse_t = serialized_sparse.vec<T>(); 330 331 Tensor output_indices; 332 Tensor output_values; 333 Tensor output_shape; 334 OP_REQUIRES_OK(context, 335 this->GetAndValidateSparseTensor( 336 serialized_sparse_t(0), serialized_sparse_t(1), 337 serialized_sparse_t(2), dtype_, 0 /* index */, 338 &output_indices, &output_values, &output_shape)); 339 context->set_output(0, output_indices); 340 context->set_output(1, output_values); 341 context->set_output(2, output_shape); 342 return; 343 } 344 345 std::vector<Tensor> indices; 346 std::vector<Tensor> values; 347 TensorShape shape; 348 indices.reserve(num_sparse_tensors); 349 values.reserve(num_sparse_tensors); 350 351 const auto& serialized_sparse_t = serialized_sparse.flat_inner_dims<T, 2>(); 352 for (int i = 0; i < num_sparse_tensors; ++i) { 353 Tensor output_indices; 354 Tensor output_values; 355 Tensor output_shape; 356 OP_REQUIRES_OK(context, 357 this->GetAndValidateSparseTensor( 358 serialized_sparse_t(i, 0), serialized_sparse_t(i, 1), 359 serialized_sparse_t(i, 2), dtype_, i, &output_indices, 360 &output_values, &output_shape)); 361 int64 num_entries = output_indices.dim_size(0); 362 int rank = output_indices.dim_size(1); 363 364 // Now we expand each SparseTensors' indices and shape by 365 // prefixing a dimension 366 Tensor expanded_indices(DT_INT64, TensorShape({num_entries, 1 + rank})); 367 const auto& output_indices_t = output_indices.matrix<int64>(); 368 auto expanded_indices_t = expanded_indices.matrix<int64>(); 369 expanded_indices_t.chip<1>(0).setZero(); 370 Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1); 371 Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank); 372 expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t; 373 374 Tensor expanded_shape(DT_INT64, TensorShape({1 + rank})); 375 const auto& output_shape_t = output_shape.vec<int64>(); 376 auto expanded_shape_t = expanded_shape.vec<int64>(); 377 expanded_shape_t(0) = 1; 378 std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1)); 379 380 TensorShape expanded_tensor_shape(expanded_shape.vec<int64>()); 381 382 indices.push_back(expanded_indices); 383 values.push_back(output_values); 384 if (i == 0) { 385 shape = expanded_tensor_shape; 386 } else { 387 OP_REQUIRES( 388 context, shape.dims() == expanded_tensor_shape.dims(), 389 errors::InvalidArgument( 390 "Inconsistent shape across SparseTensors: rank prior to " 391 "SparseTensor[", 392 i, "] was: ", shape.dims() - 1, " but rank of SparseTensor[", i, 393 "] is: ", expanded_tensor_shape.dims() - 1)); 394 for (int j = 1; j < shape.dims(); ++j) { 395 // NOTE(mrry): For compatibility with the implementations of 396 // DeserializeManySparse, and many ops that generate 397 // SparseTensors to batch that do not have a fixed 398 // dense_shape (e.g. `tf.parse_single_example()`), we 399 // compute the maximum in each dimension to find the 400 // smallest dense_shape that bounds all of the input 401 // SparseTensors. 402 shape.set_dim(j, std::max(shape.dim_size(j), 403 expanded_tensor_shape.dim_size(j))); 404 } 405 } 406 } 407 408 // Dimension 0 is the primary dimension. 409 int rank = shape.dims(); 410 gtl::InlinedVector<int64, 8> std_order(rank); 411 std::iota(std_order.begin(), std_order.end(), 0); 412 413 std::vector<SparseTensor> tensors; 414 tensors.reserve(num_sparse_tensors); 415 for (int i = 0; i < num_sparse_tensors; ++i) { 416 tensors.emplace_back(indices[i], values[i], shape, std_order); 417 } 418 419 gtl::optional<SparseTensor> maybe_output; 420 #define HANDLE_TYPE(T) \ 421 case DataTypeToEnum<T>::value: { \ 422 maybe_output = SparseTensor::Concat<T>(tensors); \ 423 break; \ 424 } 425 426 switch (dtype_) { 427 TF_CALL_ALL_TYPES(HANDLE_TYPE); 428 TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); 429 #undef HANDLE_TYPE 430 default: 431 OP_REQUIRES(context, false, 432 errors::Unimplemented( 433 "DeserializeSparse Unhandled data type: ", dtype_)); 434 } 435 DCHECK(maybe_output); 436 SparseTensor& output = maybe_output.value(); 437 438 // Compute the input shape for the reshape operation. 439 Tensor input_shape(DT_INT64, TensorShape({output.dims()})); 440 std::copy_n(output.shape().data(), output.dims(), 441 input_shape.vec<int64>().data()); 442 443 // Compute the target shape for the reshape operation. 444 Tensor target_shape(DT_INT64, TensorShape({ndims + output.dims() - 2})); 445 for (int i = 0; i < ndims - 1; ++i) { 446 target_shape.vec<int64>()(i) = serialized_sparse.shape().dim_size(i); 447 } 448 for (int i = 0; i < output.dims() - 1; ++i) { 449 target_shape.vec<int64>()(i + ndims - 1) = output.shape().data()[i + 1]; 450 } 451 452 Tensor output_indices; 453 Tensor output_shape; 454 Reshape(context, output.indices(), input_shape, target_shape, 455 0 /* output indices index */, 2 /* output shape index */); 456 context->set_output(1, output.values()); 457 } 458 459 protected: 460 Status Deserialize(const T& serialized, Tensor* result); 461 462 Status GetAndValidateSparseTensor( 463 const T& serialized_indices, const T& serialized_values, 464 const T& serialized_shape, DataType values_dtype, int index, 465 Tensor* output_indices, Tensor* output_values, Tensor* output_shape) { 466 // Deserialize and validate the indices. 467 TF_RETURN_IF_ERROR(this->Deserialize(serialized_indices, output_indices)); 468 if (!TensorShapeUtils::IsMatrix(output_indices->shape())) { 469 return errors::InvalidArgument( 470 "Expected serialized_sparse[", index, 471 ", 0] to represent an index matrix but received shape ", 472 output_indices->shape().DebugString()); 473 } 474 int64 num_entries = output_indices->dim_size(0); 475 int rank = output_indices->dim_size(1); 476 477 // Deserialize and validate the values. 478 TF_RETURN_IF_ERROR(this->Deserialize(serialized_values, output_values)); 479 if (!TensorShapeUtils::IsVector(output_values->shape())) { 480 return errors::InvalidArgument( 481 "Expected serialized_sparse[", index, 482 ", 1] to represent a values vector but received shape ", 483 output_values->shape().DebugString()); 484 } 485 if (values_dtype != output_values->dtype()) { 486 return errors::InvalidArgument( 487 "Requested SparseTensor of type ", DataTypeString(values_dtype), 488 " but SparseTensor[", index, 489 "].values.dtype() == ", DataTypeString(output_values->dtype())); 490 } 491 if (num_entries != output_values->dim_size(0)) { 492 return errors::InvalidArgument( 493 "Expected row counts of SparseTensor[", index, 494 "].indices and SparseTensor[", index, 495 "].values to match but they do not: ", num_entries, " vs. ", 496 output_values->dim_size(0)); 497 } 498 499 // Deserialize and validate the shape. 500 TF_RETURN_IF_ERROR(this->Deserialize(serialized_shape, output_shape)); 501 if (!TensorShapeUtils::IsVector(output_shape->shape())) { 502 return errors::InvalidArgument( 503 "Expected serialized_sparse[", index, 504 ", 1] to be a shape vector but its shape is ", 505 output_shape->shape().DebugString()); 506 } 507 if (rank != output_shape->dim_size(0)) { 508 return errors::InvalidArgument("Expected column counts of SparseTensor[", 509 index, 510 "].indices to match size of SparseTensor[", 511 index, "].shape but they do not: ", rank, 512 " vs. ", output_shape->dim_size(0)); 513 } 514 return Status::OK(); 515 } 516 517 DataType dtype_; 518 }; 519 520 template <> 521 Status DeserializeSparseOp<string>::Deserialize(const string& serialized, 522 Tensor* result) { 523 TensorProto proto; 524 if (!ParseProtoUnlimited(&proto, serialized)) { 525 return errors::InvalidArgument("Could not parse serialized proto"); 526 } 527 Tensor tensor; 528 if (!tensor.FromProto(proto)) { 529 return errors::InvalidArgument("Could not construct tensor from proto"); 530 } 531 *result = tensor; 532 return Status::OK(); 533 } 534 535 REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") 536 .Device(DEVICE_CPU) 537 .TypeConstraint<string>("Tserialized"), 538 DeserializeSparseOp<string>) 539 540 REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse").Device(DEVICE_CPU), 541 DeserializeSparseOp<string>) 542 543 template <> 544 Status DeserializeSparseOp<Variant>::Deserialize(const Variant& serialized, 545 Tensor* result) { 546 *result = *serialized.get<Tensor>(); 547 return Status::OK(); 548 } 549 550 REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") 551 .Device(DEVICE_CPU) 552 .TypeConstraint<Variant>("Tserialized"), 553 DeserializeSparseOp<Variant>) 554 555 } // namespace tensorflow 556