// Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // SparseDenseBinaryOpShared is the shared code for binary coefficient-wise
     17 // (cwise) operations of the following form:
     18 //
     19 //   sparse_t <binary cwise op> dense_t -> new sparse_t
     20 //
     21 // where:
     22 //
     23 //   (1) "binary cwise op" can be, for example, cdiv, cmul, cfloordiv, etc.
     24 //   (2) LIMITATION: we only support broadcasting the dense side to the sparse
     25 //       side.  In other words, NumDims(sparse_t) >= NumDims(dense_t), and if
     26 //       they are equal, each dim size of sparse_t >= that of dense_t.
     27 //   (3) Note that the result is a new sparse tensor, which means the implicitly
     28 //       zero elements of sparse_t do not participate.  (Hence, this should not
     29 //       be used for, say, cadd.)
     30 //
     31 // The only output is a vector of flat values with shape [nnz], since this op
// changes neither the indices nor the shape of the sparse operand.
     33 //
     34 // See docs of all registered ops in ../ops/sparse_ops.cc.
     35 
     36 #define EIGEN_USE_THREADS
     37 
     38 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     39 #include "tensorflow/core/framework/op_kernel.h"
     40 #include "tensorflow/core/framework/register_types.h"
     41 #include "tensorflow/core/framework/tensor.h"
     42 #include "tensorflow/core/framework/tensor_util.h"
     43 #include "tensorflow/core/framework/types.h"
     44 #include "tensorflow/core/kernels/cwise_ops.h"
     45 #include "tensorflow/core/kernels/cwise_ops_common.h"
     46 #include "tensorflow/core/util/bcast.h"
     47 
     48 using Eigen::TensorRef;
     49 using tensorflow::gtl::ArraySlice;
     50 
     51 namespace tensorflow {
     52 
     53 typedef Eigen::ThreadPoolDevice CPUDevice;
     54 
     55 template <typename Device, typename T, typename Functor>
     56 class SparseDenseBinaryOpShared : public OpKernel {
     57  public:
     58   explicit SparseDenseBinaryOpShared(OpKernelConstruction *ctx)
     59       : OpKernel(ctx) {}
     60 
     61   void Compute(OpKernelContext *ctx) override {
     62     const Tensor *indices_t, *values_t, *shape_t, *dense_t;
     63     OP_REQUIRES_OK(ctx, ctx->input("sp_indices", &indices_t));
     64     OP_REQUIRES_OK(ctx, ctx->input("sp_values", &values_t));
     65     OP_REQUIRES_OK(ctx, ctx->input("sp_shape", &shape_t));
     66     OP_REQUIRES_OK(ctx, ctx->input("dense", &dense_t));
     67 
     68     // Validations.
     69     OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(indices_t->shape()),
     70                 errors::InvalidArgument(
     71                     "Input sp_indices should be a matrix but received shape: ",
     72                     indices_t->shape().DebugString()));
     73     OP_REQUIRES(ctx,
     74                 TensorShapeUtils::IsVector(values_t->shape()) &&
     75                     TensorShapeUtils::IsVector(shape_t->shape()),
     76                 errors::InvalidArgument(
     77                     "Inputs sp_values and sp_shape should be vectors "
     78                     "but received shapes: ",
     79                     values_t->shape().DebugString(), " and ",
     80                     shape_t->shape().DebugString()));
     81     OP_REQUIRES(ctx, indices_t->dim_size(0) < std::numeric_limits<int>::max(),
     82                 errors::InvalidArgument(
     83                     "Number of non-zero elements exceeds int32 range"));
     84 
     85     const auto indices_mat = indices_t->matrix<int64>();
     86     const auto shape_vec = shape_t->vec<int64>();
     87     const auto lhs_dims = BCast::FromShape(TensorShape(shape_vec));
     88     const auto rhs_dims = BCast::FromShape(dense_t->shape());
     89     BCast b(lhs_dims, rhs_dims, false);  // false for keeping the same num dims.
     90 
     91     // True iff (size(lhs) > size(rhs)), or (sizes equal, lhs cwise rhs).
     92     auto VecGreaterEq = [](ArraySlice<int64> lhs, ArraySlice<int64> rhs) {
     93       if (lhs.size() > rhs.size()) return true;
     94       if (lhs.size() < rhs.size()) return false;
     95       for (size_t i = 0; i < lhs.size(); ++i) {
     96         if (lhs[i] < rhs[i]) return false;
     97       }
     98       return true;
     99     };
    100     OP_REQUIRES(ctx, VecGreaterEq(lhs_dims, rhs_dims) && b.IsValid(),
    101                 errors::InvalidArgument(
    102                     "SparseDenseBinaryOpShared broadcasts dense to sparse "
    103                     "only; got incompatible shapes: [",
    104                     str_util::Join(lhs_dims, ","), "] vs. [",
    105                     str_util::Join(rhs_dims, ","), "]"));
    106 
    107     Tensor *output_values = nullptr;
    108     Tensor dense_gathered;
    109     const int nnz = static_cast<int>(indices_t->dim_size(0));
    110     OP_REQUIRES_OK(ctx,
    111                    ctx->allocate_output(0, TensorShape({nnz}), &output_values));
    112     OP_REQUIRES_OK(
    113         ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, TensorShape({nnz}),
    114                                 &dense_gathered));
    115 
    116     // Pulls relevant entries from the dense side, with reshape and broadcasting
    117     // *of the dense side* taken into account.  Use a TensorRef to avoid blowing
    118     // up memory.
    119     //
    120     // We can directly use the sparse indices to look up dense side, because
    121     // "b.y_reshape()" and "b.y_bcast()" are guaranteed to have rank "ndims".
    122     auto dense_gathered_flat = dense_gathered.flat<T>();
    123     const int ndims = lhs_dims.size();
    124     switch (ndims) {
    125 #define CASE(NDIM)                                                             \
    126   case NDIM: {                                                                 \
    127     TensorRef<Eigen::Tensor<const T, NDIM, Eigen::RowMajor>> rhs_ref =         \
    128         dense_t->shaped<T, NDIM>(b.y_reshape())                                \
    129             .broadcast(BCast::ToIndexArray<NDIM>(b.y_bcast()));                \
    130     Eigen::array<Eigen::DenseIndex, NDIM> idx;                                 \
    131     bool indices_valid = true;                                                 \
    132     for (int i = 0; i < nnz; ++i) {                                            \
    133       for (int d = 0; d < NDIM; ++d) {                                         \
    134         idx[d] = internal::SubtleMustCopy(indices_mat(i, d));                  \
    135         if (!FastBoundsCheck(idx[d], rhs_ref.dimension(d))) {                  \
    136           indices_valid = false;                                               \
    137         }                                                                      \
    138       }                                                                        \
    139       OP_REQUIRES(                                                             \
    140           ctx, indices_valid,                                                  \
    141           errors::InvalidArgument("Provided indices are out-of-bounds w.r.t. " \
    142                                   "dense side with broadcasted shape"));       \
    143       dense_gathered_flat(i) = rhs_ref.coeff(idx);                             \
    144     }                                                                          \
    145     break;                                                                     \
    146   }
    147 
    148       CASE(1);
    149       CASE(2);
    150       CASE(3);
    151       CASE(4);
    152       CASE(5);
    153       default:
    154         OP_REQUIRES(
    155             ctx, false,
    156             errors::InvalidArgument("Only tensors with ranks between 1 and 5 "
    157                                     "are currently supported.  Tensor rank: ",
    158                                     ndims));
    159 #undef CASE
    160     }
    161 
    162     output_values->flat<T>().device(ctx->eigen_device<Device>()) =
    163         values_t->flat<T>().binaryExpr(dense_gathered_flat,
    164                                        typename Functor::func());
    165   }
    166 };
    167 
    168 // NOTE(aselle): If Div is extended to non-reals, make sure to use the same
    169 // separation of operator semantics as done for dense cwise ops. I.e. you
    170 // should make SparseDenseCwiseRealDiv, SparseDenseCwiseTruncateDiv,
    171 // SparseDenseCwiseFloorDiv, and then deprecate, SparseDenseCwiseDiv.
    172 // TODO(zongheng): extend to other eligible cwise operations as requested.
// Registers the CPU kernels for SparseDenseCwiseMul, SparseDenseCwiseDiv, and
// SparseDenseCwiseAdd for a single value type T.  All three share
// SparseDenseBinaryOpShared and differ only in the elementwise functor.
#define REGISTER_KERNELS(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("SparseDenseCwiseMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SparseDenseBinaryOpShared<CPUDevice, T, functor::mul<T>>)              \
                                                                             \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("SparseDenseCwiseDiv").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SparseDenseBinaryOpShared<CPUDevice, T, functor::div<T>>)              \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("SparseDenseCwiseAdd").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SparseDenseBinaryOpShared<CPUDevice, T, functor::add<T>>)

// Instantiate the registrations above for every real number type supported by
// TensorFlow (float, double, and the integer types).
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
    187 
    188 }  // namespace tensorflow
    189