// tensorflow/core/kernels/sparse_tensor_dense_add_op.cc
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include "tensorflow/core/kernels/sparse_tensor_dense_add_op.h"
     19 
     20 #include "tensorflow/core/framework/op_kernel.h"
     21 #include "tensorflow/core/framework/register_types.h"
     22 #include "tensorflow/core/framework/tensor.h"
     23 #include "tensorflow/core/framework/tensor_util.h"
     24 #include "tensorflow/core/framework/types.h"
     25 #include "tensorflow/core/util/sparse/sparse_tensor.h"
     26 
     27 namespace tensorflow {
     28 
     29 typedef Eigen::ThreadPoolDevice CPUDevice;
     30 // NOTE: does not support GPU yet.
     31 
     32 namespace {
     33 
     34 template <typename Index>
     35 Status ValidateInputs(const Tensor *a_indices, const Tensor *a_values,
     36                       const Tensor *a_shape, const Tensor *b) {
     37   if (!TensorShapeUtils::IsMatrix(a_indices->shape())) {
     38     return errors::InvalidArgument(
     39         "Input a_indices should be a matrix but received shape: ",
     40         a_indices->shape().DebugString());
     41   }
     42   if (!TensorShapeUtils::IsVector(a_values->shape()) ||
     43       !TensorShapeUtils::IsVector(a_shape->shape())) {
     44     return errors::InvalidArgument(
     45         "Inputs a_values and a_shape should be vectors "
     46         "but received shapes: ",
     47         a_values->shape().DebugString(), " and ",
     48         a_shape->shape().DebugString());
     49   }
     50   if (a_shape->NumElements() != b->dims()) {
     51     return errors::InvalidArgument(
     52         "Two operands have different ranks; received: ", a_shape->NumElements(),
     53         " and ", b->dims());
     54   }
     55   const auto a_shape_flat = a_shape->flat<Index>();
     56   for (int i = 0; i < b->dims(); ++i) {
     57     if (a_shape_flat(i) != b->dim_size(i)) {
     58       return errors::InvalidArgument(
     59           "Dimension ", i,
     60           " does not equal (no broadcasting is supported): sparse side ",
     61           a_shape_flat(i), " vs dense side ", b->dim_size(i));
     62     }
     63   }
     64   return Status::OK();
     65 }
     66 
     67 }  // namespace
     68 
// Kernel computing out = densify(A) + b, where A is a sparse tensor in COO
// form (a_indices [nnz, ndims], a_values [nnz], a_shape [ndims]) and b is a
// dense tensor of identical shape (no broadcasting). The output is a fresh
// dense tensor; neither input is modified.
template <typename Device, typename T, typename Index>
class SparseTensorDenseAddOp : public OpKernel {
 public:
  explicit SparseTensorDenseAddOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext *ctx) override {
    const Tensor *a_indices_t, *a_values_t, *a_shape_t, *b;
    OP_REQUIRES_OK(ctx, ctx->input("a_indices", &a_indices_t));
    OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values_t));
    OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape_t));
    OP_REQUIRES_OK(ctx, ctx->input("b", &b));
    OP_REQUIRES_OK(
        ctx, ValidateInputs<Index>(a_indices_t, a_values_t, a_shape_t, b));

    Tensor *out_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, b->shape(), &out_t));

    // Rank of the operands: one coordinate column per dimension in a_indices.
    const int ndims = static_cast<int>(a_indices_t->dim_size(1));
    const auto a_indices_mat = a_indices_t->flat_inner_dims<Index>();
    const auto a_values_flat = a_values_t->flat<T>();

    // The scatter functor needs the rank as a compile-time constant, so we
    // dispatch through a switch with one instantiation per supported rank.
    // Each case (1) copies b into the output, then (2) scatter-adds the
    // sparse values. The functor returns -1 on success, or the index of the
    // first dimension where some sparse coordinate was out of bounds.
    // NOTE: no comments inside the macro body — '//' would swallow the '\'
    // line continuations.
    switch (ndims) {
#define NDIMS_CASE(N)                                                     \
  case N: {                                                               \
    auto out_tensor = out_t->tensor<T, N>();                              \
    out_tensor.device(ctx->eigen_device<Device>()) = b->tensor<T, N>();   \
    const Index result =                                                  \
        functor::ScatterNdFunctor<Device, T, Index, N,                    \
                                  scatter_op::UpdateOp::ADD>()(           \
            ctx->eigen_device<Device>(), a_indices_mat, a_values_flat,    \
            out_tensor);                                                  \
    OP_REQUIRES(                                                          \
        ctx, result == -1,                                                \
        errors::InvalidArgument(                                          \
            "Sparse tensor has some invalid index on dimension ", result, \
            "; dense tensor shape: ", b->shape().DebugString()));         \
  } break;

      NDIMS_CASE(1);
      NDIMS_CASE(2);
      NDIMS_CASE(3);
      NDIMS_CASE(4);
      NDIMS_CASE(5);
      default:
        // Ranks outside [1, 5] are rejected rather than supported
        // dynamically, since each rank requires its own template instance.
        OP_REQUIRES(
            ctx, false,
            errors::InvalidArgument("Only tensors with ranks between 1 and 5 "
                                    "are currently supported.  Tensor rank: ",
                                    ndims));
#undef NDIMS_CASE
    }
  }
};
    122 
    123 namespace functor {
    124 template <typename T, typename Index, int NDIMS>
    125 struct ScatterNdFunctor<CPUDevice, T, Index, NDIMS, scatter_op::UpdateOp::ADD> {
    126   Index operator()(const CPUDevice &d,
    127                    typename TTypes<Index>::ConstMatrix indices,
    128                    typename TTypes<T>::ConstFlat updates,
    129                    typename TTypes<T, NDIMS>::Tensor out) {
    130     Eigen::array<Eigen::DenseIndex, NDIMS> idx;
    131     const int num_nnz = static_cast<int>(indices.dimension(0));
    132     for (int i = 0; i < num_nnz; ++i) {
    133       for (int d = 0; d < NDIMS; ++d) {
    134         idx[d] = internal::SubtleMustCopy(indices(i, d));
    135         if (!FastBoundsCheck(idx[d], out.dimension(d))) {
    136           return d;  // on failure: d nonnegative
    137         }
    138       }
    139       out(idx) += updates(i);
    140     }
    141     return -1;  // on success
    142   }
    143 };
    144 }  // namespace functor
    145 
// Registers the CPU kernel for one (value type, index type) combination.
#define REGISTER_KERNELS_CPU(TypeT, TypeIndex)                        \
  REGISTER_KERNEL_BUILDER(Name("SparseTensorDenseAdd")                \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<TypeT>("T")             \
                              .TypeConstraint<TypeIndex>("Tindices"), \
                          SparseTensorDenseAddOp<CPUDevice, TypeT, TypeIndex>)

// Registers both supported index types (int32 and int64) for a value type.
#define REGISTER_KERNELS(T)       \
  REGISTER_KERNELS_CPU(T, int64); \
  REGISTER_KERNELS_CPU(T, int32)

// Instantiates the registration for every standard TF numeric value type.
TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_CPU
    160 }  // namespace tensorflow
    161