Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include <algorithm>
     19 #include <numeric>
     20 #include <unordered_map>
     21 #include <utility>
     22 #include <vector>
     23 
     24 #include "tensorflow/core/framework/op_kernel.h"
     25 #include "tensorflow/core/framework/register_types.h"
     26 #include "tensorflow/core/framework/tensor.h"
     27 #include "tensorflow/core/framework/tensor.pb.h"
     28 #include "tensorflow/core/framework/tensor_util.h"
     29 #include "tensorflow/core/framework/types.h"
     30 #include "tensorflow/core/framework/variant.h"
     31 #include "tensorflow/core/framework/variant_encode_decode.h"
     32 #include "tensorflow/core/kernels/reshape_util.h"
     33 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     34 #include "tensorflow/core/lib/gtl/optional.h"
     35 #include "tensorflow/core/util/sparse/sparse_tensor.h"
     36 
     37 namespace tensorflow {
     38 
     39 using sparse::SparseTensor;
     40 
     41 template <typename T>
     42 class SerializeSparseOp : public OpKernel {
     43  public:
     44   explicit SerializeSparseOp(OpKernelConstruction* context)
     45       : OpKernel(context) {}
     46 
     47   Status Initialize(Tensor* result);
     48   Status Serialize(const Tensor& input, T* result);
     49 
     50   void Compute(OpKernelContext* context) override {
     51     const Tensor* input_indices;
     52     const Tensor* input_values;
     53     const Tensor* input_shape;
     54 
     55     OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
     56     OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
     57     OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
     58     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
     59                 errors::InvalidArgument(
     60                     "Input indices should be a matrix but received shape ",
     61                     input_indices->shape().DebugString()));
     62 
     63     OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
     64                 errors::InvalidArgument(
     65                     "Input values should be a vector but received shape ",
     66                     input_values->shape().DebugString()));
     67 
     68     OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
     69                 errors::InvalidArgument(
     70                     "Input shape should be a vector but received shape ",
     71                     input_shape->shape().DebugString()));
     72 
     73     Tensor serialized_sparse;
     74     OP_REQUIRES_OK(context, Initialize(&serialized_sparse));
     75 
     76     auto serialized_sparse_t = serialized_sparse.vec<T>();
     77     OP_REQUIRES_OK(context, Serialize(*input_indices, &serialized_sparse_t(0)));
     78     OP_REQUIRES_OK(context, Serialize(*input_values, &serialized_sparse_t(1)));
     79     OP_REQUIRES_OK(context, Serialize(*input_shape, &serialized_sparse_t(2)));
     80 
     81     context->set_output(0, serialized_sparse);
     82   }
     83 };
     84 
     85 template <>
     86 Status SerializeSparseOp<string>::Initialize(Tensor* result) {
     87   *result = Tensor(DT_STRING, TensorShape({3}));
     88   return Status::OK();
     89 }
     90 
     91 template <>
     92 Status SerializeSparseOp<string>::Serialize(const Tensor& input,
     93                                             string* result) {
     94   TensorProto proto;
     95   input.AsProtoTensorContent(&proto);
     96   *result = proto.SerializeAsString();
     97   return Status::OK();
     98 }
     99 
    100 REGISTER_KERNEL_BUILDER(Name("SerializeSparse")
    101                             .Device(DEVICE_CPU)
    102                             .TypeConstraint<string>("out_type"),
    103                         SerializeSparseOp<string>);
    104 
    105 template <>
    106 Status SerializeSparseOp<Variant>::Initialize(Tensor* result) {
    107   *result = Tensor(DT_VARIANT, TensorShape({3}));
    108   return Status::OK();
    109 }
    110 
    111 template <>
    112 Status SerializeSparseOp<Variant>::Serialize(const Tensor& input,
    113                                              Variant* result) {
    114   *result = input;
    115   return Status::OK();
    116 }
    117 
    118 REGISTER_KERNEL_BUILDER(Name("SerializeSparse")
    119                             .Device(DEVICE_CPU)
    120                             .TypeConstraint<Variant>("out_type"),
    121                         SerializeSparseOp<Variant>);
    122 
    123 template <typename T>
    124 class SerializeManySparseOpBase : public OpKernel {
    125  public:
    126   explicit SerializeManySparseOpBase(OpKernelConstruction* context)
    127       : OpKernel(context) {}
    128 
    129   void Compute(OpKernelContext* context) override {}
    130 
    131  protected:
    132   Status Initialize(const int64 n, Tensor* result);
    133   Status Serialize(const Tensor& input, T* result);
    134 };
    135 
    136 template <typename T, typename U>
    137 class SerializeManySparseOp : public SerializeManySparseOpBase<U> {
    138  public:
    139   explicit SerializeManySparseOp(OpKernelConstruction* context)
    140       : SerializeManySparseOpBase<U>(context) {}
    141 
    142   void Compute(OpKernelContext* context) override {
    143     const Tensor* input_indices;
    144     const Tensor* input_values;
    145     const Tensor* input_shape;
    146     OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    147     OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    148     OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    149     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
    150                 errors::InvalidArgument(
    151                     "Input indices should be a matrix but received shape ",
    152                     input_indices->shape().DebugString()));
    153 
    154     OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
    155                 errors::InvalidArgument(
    156                     "Input values should be a vector but received shape ",
    157                     input_values->shape().DebugString()));
    158 
    159     OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
    160                 errors::InvalidArgument(
    161                     "Input shape should be a vector but received shape ",
    162                     input_shape->shape().DebugString()));
    163 
    164     int rank = input_shape->NumElements();
    165 
    166     OP_REQUIRES(
    167         context, rank > 1,
    168         errors::InvalidArgument(
    169             "Rank of input SparseTensor should be > 1, but saw rank: ", rank));
    170 
    171     TensorShape tensor_input_shape(input_shape->vec<int64>());
    172     gtl::InlinedVector<int64, 8> std_order(rank);
    173     std::iota(std_order.begin(), std_order.end(), 0);
    174     SparseTensor input_st(*input_indices, *input_values, tensor_input_shape,
    175                           std_order);
    176 
    177     auto input_shape_t = input_shape->vec<int64>();
    178     const int64 N = input_shape_t(0);
    179     Tensor serialized_sparse;
    180     OP_REQUIRES_OK(context, this->Initialize(N, &serialized_sparse));
    181     auto serialized_sparse_t = serialized_sparse.matrix<U>();
    182 
    183     OP_REQUIRES_OK(context, input_st.IndicesValid());
    184 
    185     // Initialize output with empty values and the proper shapes.
    186     Tensor output_blank_indices(DT_INT64, {0, rank - 1});
    187     U serialized_indices;
    188     OP_REQUIRES_OK(context,
    189                    this->Serialize(output_blank_indices, &serialized_indices));
    190     serialized_sparse_t.template chip<1>(0).setConstant(serialized_indices);
    191 
    192     Tensor output_blank_values(DataTypeToEnum<T>::value, {0});
    193     U serialized_values;
    194     OP_REQUIRES_OK(context,
    195                    this->Serialize(output_blank_values, &serialized_values));
    196     serialized_sparse_t.template chip<1>(1).setConstant(serialized_values);
    197 
    198     Tensor output_shape(DT_INT64, {rank - 1});
    199     auto output_shape_t = output_shape.vec<int64>();
    200     for (int d = 1; d < rank; d++) output_shape_t(d - 1) = input_shape_t(d);
    201     U serialized_shape;
    202     OP_REQUIRES_OK(context, this->Serialize(output_shape, &serialized_shape));
    203     serialized_sparse_t.template chip<1>(2).setConstant(serialized_shape);
    204 
    205     // Get groups by minibatch dimension
    206     sparse::GroupIterable minibatch = input_st.group({0});
    207     for (const auto& subset : minibatch) {
    208       const int64 b = subset.group()[0];
    209       OP_REQUIRES(
    210           context, b > -1 && b < N,
    211           errors::InvalidArgument(
    212               "Received unexpected column 0 value in input SparseTensor: ", b,
    213               " < 0 or >= N (= ", N, ")"));
    214 
    215       const auto indices = subset.indices();
    216       const auto values = subset.values<T>();
    217       const int64 num_entries = values.size();
    218 
    219       Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
    220       Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});
    221 
    222       auto output_indices_t = output_indices.matrix<int64>();
    223       auto output_values_t = output_values.vec<T>();
    224 
    225       for (int i = 0; i < num_entries; ++i) {
    226         for (int d = 1; d < rank; ++d) {
    227           output_indices_t(i, d - 1) = indices(i, d);
    228         }
    229         output_values_t(i) = values(i);
    230       }
    231 
    232       OP_REQUIRES_OK(
    233           context, this->Serialize(output_indices, &serialized_sparse_t(b, 0)));
    234       OP_REQUIRES_OK(
    235           context, this->Serialize(output_values, &serialized_sparse_t(b, 1)));
    236     }
    237 
    238     context->set_output(0, serialized_sparse);
    239   }
    240 };
    241 
    242 template <>
    243 Status SerializeManySparseOpBase<string>::Initialize(const int64 n,
    244                                                      Tensor* result) {
    245   *result = Tensor(DT_STRING, TensorShape({n, 3}));
    246   return Status::OK();
    247 }
    248 
    249 template <>
    250 Status SerializeManySparseOpBase<string>::Serialize(const Tensor& input,
    251                                                     string* result) {
    252   TensorProto proto;
    253   input.AsProtoTensorContent(&proto);
    254   *result = proto.SerializeAsString();
    255   return Status::OK();
    256 }
    257 
    258 #define REGISTER_KERNELS(type)                                     \
    259   REGISTER_KERNEL_BUILDER(Name("SerializeManySparse")              \
    260                               .Device(DEVICE_CPU)                  \
    261                               .TypeConstraint<type>("T")           \
    262                               .TypeConstraint<string>("out_type"), \
    263                           SerializeManySparseOp<type, string>)
    264 
    265 TF_CALL_ALL_TYPES(REGISTER_KERNELS);
    266 #undef REGISTER_KERNELS
    267 
    268 template <>
    269 Status SerializeManySparseOpBase<Variant>::Initialize(const int64 n,
    270                                                       Tensor* result) {
    271   *result = Tensor(DT_VARIANT, TensorShape({n, 3}));
    272   return Status::OK();
    273 }
    274 
    275 template <>
    276 Status SerializeManySparseOpBase<Variant>::Serialize(const Tensor& input,
    277                                                      Variant* result) {
    278   *result = input;
    279   return Status::OK();
    280 }
    281 
    282 #define REGISTER_KERNELS(type)                                      \
    283   REGISTER_KERNEL_BUILDER(Name("SerializeManySparse")               \
    284                               .Device(DEVICE_CPU)                   \
    285                               .TypeConstraint<type>("T")            \
    286                               .TypeConstraint<Variant>("out_type"), \
    287                           SerializeManySparseOp<type, Variant>)
    288 
    289 TF_CALL_ALL_TYPES(REGISTER_KERNELS);
    290 #undef REGISTER_KERNELS
    291 
    292 template <typename T>
    293 class DeserializeSparseOp : public OpKernel {
    294  public:
    295   explicit DeserializeSparseOp(OpKernelConstruction* context)
    296       : OpKernel(context) {
    297     OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
    298   }
    299 
    300   void Compute(OpKernelContext* context) override {
    301     const Tensor& serialized_sparse = context->input(0);
    302     const int ndims = serialized_sparse.shape().dims();
    303 
    304     OP_REQUIRES(
    305         context, ndims > 0,
    306         errors::InvalidArgument("Serialized sparse should have non-zero rank ",
    307                                 serialized_sparse.shape().DebugString()));
    308 
    309     OP_REQUIRES(context, serialized_sparse.shape().dim_size(ndims - 1) == 3,
    310                 errors::InvalidArgument(
    311                     "Serialized sparse should have 3 as the last dimension ",
    312                     serialized_sparse.shape().DebugString()));
    313 
    314     int num_sparse_tensors = 1;
    315     for (int i = 0; i < ndims - 1; ++i) {
    316       num_sparse_tensors *= serialized_sparse.shape().dim_size(i);
    317     }
    318 
    319     OP_REQUIRES(
    320         context, num_sparse_tensors > 0,
    321         errors::InvalidArgument(
    322             "Serialized sparse should have at least 1 serialized tensor, "
    323             "but has a zero dimension ",
    324             serialized_sparse.shape().DebugString()));
    325 
    326     if (num_sparse_tensors == 0 && serialized_sparse.shape().dims() == 1) {
    327       // Special case with a single sparse tensor. We can avoid data
    328       // motion in the Concat and Reshape.
    329       const auto& serialized_sparse_t = serialized_sparse.vec<T>();
    330 
    331       Tensor output_indices;
    332       Tensor output_values;
    333       Tensor output_shape;
    334       OP_REQUIRES_OK(context,
    335                      this->GetAndValidateSparseTensor(
    336                          serialized_sparse_t(0), serialized_sparse_t(1),
    337                          serialized_sparse_t(2), dtype_, 0 /* index */,
    338                          &output_indices, &output_values, &output_shape));
    339       context->set_output(0, output_indices);
    340       context->set_output(1, output_values);
    341       context->set_output(2, output_shape);
    342       return;
    343     }
    344 
    345     std::vector<Tensor> indices;
    346     std::vector<Tensor> values;
    347     TensorShape shape;
    348     indices.reserve(num_sparse_tensors);
    349     values.reserve(num_sparse_tensors);
    350 
    351     const auto& serialized_sparse_t = serialized_sparse.flat_inner_dims<T, 2>();
    352     for (int i = 0; i < num_sparse_tensors; ++i) {
    353       Tensor output_indices;
    354       Tensor output_values;
    355       Tensor output_shape;
    356       OP_REQUIRES_OK(context,
    357                      this->GetAndValidateSparseTensor(
    358                          serialized_sparse_t(i, 0), serialized_sparse_t(i, 1),
    359                          serialized_sparse_t(i, 2), dtype_, i, &output_indices,
    360                          &output_values, &output_shape));
    361       int64 num_entries = output_indices.dim_size(0);
    362       int rank = output_indices.dim_size(1);
    363 
    364       // Now we expand each SparseTensors' indices and shape by
    365       // prefixing a dimension
    366       Tensor expanded_indices(DT_INT64, TensorShape({num_entries, 1 + rank}));
    367       const auto& output_indices_t = output_indices.matrix<int64>();
    368       auto expanded_indices_t = expanded_indices.matrix<int64>();
    369       expanded_indices_t.chip<1>(0).setZero();
    370       Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
    371       Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
    372       expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t;
    373 
    374       Tensor expanded_shape(DT_INT64, TensorShape({1 + rank}));
    375       const auto& output_shape_t = output_shape.vec<int64>();
    376       auto expanded_shape_t = expanded_shape.vec<int64>();
    377       expanded_shape_t(0) = 1;
    378       std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
    379 
    380       TensorShape expanded_tensor_shape(expanded_shape.vec<int64>());
    381 
    382       indices.push_back(expanded_indices);
    383       values.push_back(output_values);
    384       if (i == 0) {
    385         shape = expanded_tensor_shape;
    386       } else {
    387         OP_REQUIRES(
    388             context, shape.dims() == expanded_tensor_shape.dims(),
    389             errors::InvalidArgument(
    390                 "Inconsistent shape across SparseTensors: rank prior to "
    391                 "SparseTensor[",
    392                 i, "] was: ", shape.dims() - 1, " but rank of SparseTensor[", i,
    393                 "] is: ", expanded_tensor_shape.dims() - 1));
    394         for (int j = 1; j < shape.dims(); ++j) {
    395           // NOTE(mrry): For compatibility with the implementations of
    396           // DeserializeManySparse, and many ops that generate
    397           // SparseTensors to batch that do not have a fixed
    398           // dense_shape (e.g. `tf.parse_single_example()`), we
    399           // compute the maximum in each dimension to find the
    400           // smallest dense_shape that bounds all of the input
    401           // SparseTensors.
    402           shape.set_dim(j, std::max(shape.dim_size(j),
    403                                     expanded_tensor_shape.dim_size(j)));
    404         }
    405       }
    406     }
    407 
    408     // Dimension 0 is the primary dimension.
    409     int rank = shape.dims();
    410     gtl::InlinedVector<int64, 8> std_order(rank);
    411     std::iota(std_order.begin(), std_order.end(), 0);
    412 
    413     std::vector<SparseTensor> tensors;
    414     tensors.reserve(num_sparse_tensors);
    415     for (int i = 0; i < num_sparse_tensors; ++i) {
    416       tensors.emplace_back(indices[i], values[i], shape, std_order);
    417     }
    418 
    419     gtl::optional<SparseTensor> maybe_output;
    420 #define HANDLE_TYPE(T)                               \
    421   case DataTypeToEnum<T>::value: {                   \
    422     maybe_output = SparseTensor::Concat<T>(tensors); \
    423     break;                                           \
    424   }
    425 
    426     switch (dtype_) {
    427       TF_CALL_ALL_TYPES(HANDLE_TYPE);
    428       TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
    429 #undef HANDLE_TYPE
    430       default:
    431         OP_REQUIRES(context, false,
    432                     errors::Unimplemented(
    433                         "DeserializeSparse Unhandled data type: ", dtype_));
    434     }
    435     DCHECK(maybe_output);
    436     SparseTensor& output = maybe_output.value();
    437 
    438     // Compute the input shape for the reshape operation.
    439     Tensor input_shape(DT_INT64, TensorShape({output.dims()}));
    440     std::copy_n(output.shape().data(), output.dims(),
    441                 input_shape.vec<int64>().data());
    442 
    443     // Compute the target shape for the reshape operation.
    444     Tensor target_shape(DT_INT64, TensorShape({ndims + output.dims() - 2}));
    445     for (int i = 0; i < ndims - 1; ++i) {
    446       target_shape.vec<int64>()(i) = serialized_sparse.shape().dim_size(i);
    447     }
    448     for (int i = 0; i < output.dims() - 1; ++i) {
    449       target_shape.vec<int64>()(i + ndims - 1) = output.shape().data()[i + 1];
    450     }
    451 
    452     Tensor output_indices;
    453     Tensor output_shape;
    454     Reshape(context, output.indices(), input_shape, target_shape,
    455             0 /* output indices index */, 2 /* output shape index */);
    456     context->set_output(1, output.values());
    457   }
    458 
    459  protected:
    460   Status Deserialize(const T& serialized, Tensor* result);
    461 
    462   Status GetAndValidateSparseTensor(
    463       const T& serialized_indices, const T& serialized_values,
    464       const T& serialized_shape, DataType values_dtype, int index,
    465       Tensor* output_indices, Tensor* output_values, Tensor* output_shape) {
    466     // Deserialize and validate the indices.
    467     TF_RETURN_IF_ERROR(this->Deserialize(serialized_indices, output_indices));
    468     if (!TensorShapeUtils::IsMatrix(output_indices->shape())) {
    469       return errors::InvalidArgument(
    470           "Expected serialized_sparse[", index,
    471           ", 0] to represent an index matrix but received shape ",
    472           output_indices->shape().DebugString());
    473     }
    474     int64 num_entries = output_indices->dim_size(0);
    475     int rank = output_indices->dim_size(1);
    476 
    477     // Deserialize and validate the values.
    478     TF_RETURN_IF_ERROR(this->Deserialize(serialized_values, output_values));
    479     if (!TensorShapeUtils::IsVector(output_values->shape())) {
    480       return errors::InvalidArgument(
    481           "Expected serialized_sparse[", index,
    482           ", 1] to represent a values vector but received shape ",
    483           output_values->shape().DebugString());
    484     }
    485     if (values_dtype != output_values->dtype()) {
    486       return errors::InvalidArgument(
    487           "Requested SparseTensor of type ", DataTypeString(values_dtype),
    488           " but SparseTensor[", index,
    489           "].values.dtype() == ", DataTypeString(output_values->dtype()));
    490     }
    491     if (num_entries != output_values->dim_size(0)) {
    492       return errors::InvalidArgument(
    493           "Expected row counts of SparseTensor[", index,
    494           "].indices and SparseTensor[", index,
    495           "].values to match but they do not: ", num_entries, " vs. ",
    496           output_values->dim_size(0));
    497     }
    498 
    499     // Deserialize and validate the shape.
    500     TF_RETURN_IF_ERROR(this->Deserialize(serialized_shape, output_shape));
    501     if (!TensorShapeUtils::IsVector(output_shape->shape())) {
    502       return errors::InvalidArgument(
    503           "Expected serialized_sparse[", index,
    504           ", 1] to be a shape vector but its shape is ",
    505           output_shape->shape().DebugString());
    506     }
    507     if (rank != output_shape->dim_size(0)) {
    508       return errors::InvalidArgument("Expected column counts of SparseTensor[",
    509                                      index,
    510                                      "].indices to match size of SparseTensor[",
    511                                      index, "].shape but they do not: ", rank,
    512                                      " vs. ", output_shape->dim_size(0));
    513     }
    514     return Status::OK();
    515   }
    516 
    517   DataType dtype_;
    518 };
    519 
    520 template <>
    521 Status DeserializeSparseOp<string>::Deserialize(const string& serialized,
    522                                                 Tensor* result) {
    523   TensorProto proto;
    524   if (!ParseProtoUnlimited(&proto, serialized)) {
    525     return errors::InvalidArgument("Could not parse serialized proto");
    526   }
    527   Tensor tensor;
    528   if (!tensor.FromProto(proto)) {
    529     return errors::InvalidArgument("Could not construct tensor from proto");
    530   }
    531   *result = tensor;
    532   return Status::OK();
    533 }
    534 
    535 REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
    536                             .Device(DEVICE_CPU)
    537                             .TypeConstraint<string>("Tserialized"),
    538                         DeserializeSparseOp<string>)
    539 
    540 REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse").Device(DEVICE_CPU),
    541                         DeserializeSparseOp<string>)
    542 
    543 template <>
    544 Status DeserializeSparseOp<Variant>::Deserialize(const Variant& serialized,
    545                                                  Tensor* result) {
    546   *result = *serialized.get<Tensor>();
    547   return Status::OK();
    548 }
    549 
    550 REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
    551                             .Device(DEVICE_CPU)
    552                             .TypeConstraint<Variant>("Tserialized"),
    553                         DeserializeSparseOp<Variant>)
    554 
    555 }  // namespace tensorflow
    556