Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See core/ops/sparse_ops.cc for documentation.
     17 //
     18 // NOTE: the operations in this file are only suitable for execution
     19 // on CPUs.
     20 
     21 #define EIGEN_USE_THREADS
     22 
     23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     24 
     25 #include <numeric>
     26 #include <sstream>
     27 #include <string>
     28 #include <unordered_map>
     29 #include <utility>
     30 
     31 #include "tensorflow/core/framework/op_kernel.h"
     32 #include "tensorflow/core/framework/register_types.h"
     33 #include "tensorflow/core/framework/tensor.h"
     34 #include "tensorflow/core/framework/types.h"
     35 #include "tensorflow/core/lib/core/status.h"
     36 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     37 #include "tensorflow/core/lib/strings/stringprintf.h"
     38 #include "tensorflow/core/util/sparse/sparse_tensor.h"
     39 
     40 namespace tensorflow {
     41 
     42 // Operator to convert sparse representations to dense.
     43 template <typename T, typename Index>
     44 class SparseToDense : public OpKernel {
     45  public:
     46   explicit SparseToDense(OpKernelConstruction* context) : OpKernel(context) {
     47     OP_REQUIRES_OK(context,
     48                    context->GetAttr("validate_indices", &validate_indices_));
     49   }
     50 
     51   void Compute(OpKernelContext* c) override {
     52     // sparse_indices
     53     const Tensor& indices = c->input(0);
     54     OP_REQUIRES(c, indices.dims() <= 2,
     55                 errors::InvalidArgument(
     56                     "sparse_indices should be a scalar, vector, or matrix, "
     57                     "got shape ",
     58                     indices.shape().DebugString()));
     59     const int64 num_elems = indices.dims() > 0 ? indices.dim_size(0) : 1;
     60     const int64 num_dims = indices.dims() > 1 ? indices.dim_size(1) : 1;
     61 
     62     // output_shape
     63     const Tensor& output_shape = c->input(1);
     64     OP_REQUIRES(
     65         c, IsLegacyVector(output_shape.shape()),
     66         errors::InvalidArgument("output_shape should be a vector, got shape ",
     67                                 output_shape.shape().DebugString()));
     68     OP_REQUIRES(c, output_shape.NumElements() == num_dims,
     69                 errors::InvalidArgument(
     70                     "output_shape has incorrect number of elements: ",
     71                     output_shape.NumElements(), " should be: ", num_dims));
     72 
     73     // sparse_values
     74     const Tensor& sparse_values = c->input(2);
     75     const int64 num_values = sparse_values.NumElements();
     76     OP_REQUIRES(c,
     77                 sparse_values.dims() == 0 ||
     78                     (sparse_values.dims() == 1 && num_values == num_elems),
     79                 errors::InvalidArgument("sparse_values has incorrect shape ",
     80                                         sparse_values.shape().DebugString(),
     81                                         ", should be [] or [", num_elems, "]"));
     82 
     83     // default_value
     84     const Tensor& default_value = c->input(3);
     85     OP_REQUIRES(c, TensorShapeUtils::IsScalar(default_value.shape()),
     86                 errors::InvalidArgument("default_value should be a scalar."));
     87 
     88     auto output_shape_vec = output_shape.flat<Index>();
     89     TensorShape output_tensor_shape;
     90     OP_REQUIRES_OK(c, TensorShapeUtils::MakeShape(output_shape_vec.data(),
     91                                                   output_shape_vec.size(),
     92                                                   &output_tensor_shape));
     93     Tensor* output = nullptr;
     94     OP_REQUIRES_OK(c, c->allocate_output(0, output_tensor_shape, &output));
     95 
     96     TensorShape ix_shape({num_elems, num_dims});
     97     Tensor indices_shaped(DT_INT64, ix_shape);
     98     if (indices.dtype() == DT_INT64) {
     99       CHECK(indices_shaped.CopyFrom(indices, ix_shape));
    100     } else {
    101       indices_shaped.matrix<int64>() =
    102           indices.shaped<Index, 2>(ix_shape.dim_sizes()).template cast<int64>();
    103     }
    104 
    105     // If we received a scalar, we'll need to create a new
    106     // tensor with copies of the values as a vec.
    107     // TODO(ebrevdo): find a way to avoid this temp allocation.
    108     Tensor sparse_values_b;
    109 
    110     if (TensorShapeUtils::IsScalar(sparse_values.shape())) {
    111       OP_REQUIRES_OK(
    112           c, c->allocate_temp(DataTypeToEnum<T>::value,
    113                               TensorShape({num_elems}), &sparse_values_b));
    114       sparse_values_b.vec<T>().setConstant(sparse_values.scalar<T>()());
    115     } else {
    116       sparse_values_b = sparse_values;
    117     }
    118 
    119     // Assume SparseTensor is lexicographically sorted.
    120     gtl::InlinedVector<int64, 8> order(output->shape().dims());
    121     std::iota(order.begin(), order.end(), 0);
    122     sparse::SparseTensor st(indices_shaped, sparse_values_b, output->shape(),
    123                             order);
    124 
    125     if (validate_indices_) {
    126       OP_REQUIRES_OK(c, st.IndicesValid());
    127     }
    128 
    129     output->flat<T>().setConstant(default_value.scalar<T>()());
    130     OP_REQUIRES(c, st.template ToDense<T>(output, false /* initialize */),
    131                 errors::InvalidArgument(
    132                     "Indices are not valid (out of bounds).  Shape: ",
    133                     output->shape().DebugString()));
    134   }
    135 
    136  private:
    137   bool validate_indices_;
    138 };
    139 
    140 #define REGISTER_KERNELS(type, index_type)                             \
    141   REGISTER_KERNEL_BUILDER(Name("SparseToDense")                        \
    142                               .Device(DEVICE_CPU)                      \
    143                               .TypeConstraint<type>("T")               \
    144                               .TypeConstraint<index_type>("Tindices"), \
    145                           SparseToDense<type, index_type>);
    146 
    147 #define REGISTER_KERNELS_ALL(type) \
    148   REGISTER_KERNELS(type, int32);   \
    149   REGISTER_KERNELS(type, int64);
    150 
    151 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS_ALL);
    152 REGISTER_KERNELS_ALL(bool);
    153 REGISTER_KERNELS_ALL(string);
    154 
    155 #undef REGISTER_KERNELS_ALL
    156 #undef REGISTER_KERNELS
    157 
    158 }  // namespace tensorflow
    159