/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.
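//
// This kernel computes out = op(A) * op(B), where A is a rank-2 SparseTensor
// in COO form (a_indices, a_values, a_shape), B is a dense matrix, and op()
// is either the identity or the adjoint (conjugate transpose), selected by
// the adjoint_a / adjoint_b attrs.
//
// Example (adjoint_a = adjoint_b = false):
//   A = SparseTensor(indices = [[0, 0], [1, 1]], values = [2, 3],
//                    shape = [2, 2])   (dense: [[2, 0], [0, 3]])
//   B = [[1, 2], [3, 4]]
//   out = [[2, 4], [9, 12]], i.e. out(m, n) = sum_k op(A)(m, k) * op(B)(k, n).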

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/fill_functor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, typename Tindices>
class SparseTensorDenseMatMulOp : public OpKernel {
 public:
  explicit SparseTensorDenseMatMulOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint_a", &adjoint_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint_b", &adjoint_b_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* a_indices;
    const Tensor* a_values;
    const Tensor* a_shape;
    const Tensor* b;
    OP_REQUIRES_OK(ctx, ctx->input("a_indices", &a_indices));
    OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values));
    OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape));
    OP_REQUIRES_OK(ctx, ctx->input("b", &b));

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b->shape()),
                errors::InvalidArgument("Tensor 'b' is not a matrix"));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_shape->shape()),
                errors::InvalidArgument("Tensor 'a_shape' is not a vector"));

    OP_REQUIRES(
        ctx, a_shape->NumElements() == 2,
        errors::InvalidArgument("Tensor 'a_shape' must have 2 elements"));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_values->shape()),
                errors::InvalidArgument("Tensor 'a_values' is not a vector"));

    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_indices->shape()),
                errors::InvalidArgument("Tensor 'a_indices' is not a matrix"));

    const int64 nnz = a_indices->shape().dim_size(0);
    OP_REQUIRES(ctx, nnz == a_values->NumElements(),
                errors::InvalidArgument("Number of rows of a_indices does not "
                                        "match number of entries in a_values"));

    OP_REQUIRES(
        ctx, a_indices->shape().dim_size(1) == a_shape->NumElements(),
        errors::InvalidArgument("Number of columns of a_indices does not match "
                                "number of entries in a_shape"));

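    // Shapes after applying the optional adjoints:
    //   op(A): [outer_left, inner_left],  op(B): [inner_right, outer_right].
    // The matmul contracts inner_left against inner_right and produces an
    // [outer_left, outer_right] output.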
    auto a_shape_t = a_shape->vec<int64>();
    const int64 outer_left = (adjoint_a_) ? a_shape_t(1) : a_shape_t(0);
    const int64 outer_right =
        (adjoint_b_) ? b->shape().dim_size(0) : b->shape().dim_size(1);
    const int64 inner_left = (adjoint_a_) ? a_shape_t(0) : a_shape_t(1);
    const int64 inner_right =
        (adjoint_b_) ? b->shape().dim_size(1) : b->shape().dim_size(0);

    OP_REQUIRES(
        ctx, inner_right == inner_left,
        errors::InvalidArgument(
            "Cannot multiply A and B because inner dimension does not match: ",
            inner_left, " vs. ", inner_right,
            ".  Did you forget a transpose?  "
            "Dimensions of A: [",
            a_shape_t(0), ", ", a_shape_t(1),
            "].  Dimensions of B: ", b->shape().DebugString()));

    if (std::is_same<Device, GPUDevice>::value) {
      // The GPU implementation is optimized to use 32-bit indexing, so
      // give the programmer a friendly error early on if any dimension
      // exceeds that limit.
      const int int32max = std::numeric_limits<int>::max();
      OP_REQUIRES(
          ctx,
          (FastBoundsCheck(inner_left, int32max) &&
           FastBoundsCheck(inner_right, int32max) &&
           FastBoundsCheck(outer_left, int32max) &&
           FastBoundsCheck(outer_right, int32max) &&
           FastBoundsCheck(b->NumElements(), int32max) &&
           FastBoundsCheck(outer_left * outer_right, int32max) &&
           FastBoundsCheck(a_values->NumElements(), int32max)),
          errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs"));
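      // The GPU kernel enumerates (nonzero, output column) pairs with a
      // single 32-bit index, so nnz * outer_right must also fit in 32 bits.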
      OP_REQUIRES(ctx, FastBoundsCheck(nnz * outer_right, int32max),
                  errors::InvalidArgument(
                      "Cannot use GPU when output.shape[1] * nnz(a) > 2^31"));
    }

    TensorShape out_shape({outer_left, outer_right});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a_values->NumElements() == 0 || b->NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

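// MAYBE_ADJOINT dispatches the runtime (adjoint_a_, adjoint_b_) pair to the
// matching functor specialization; ADJ_A and ADJ_B are compile-time template
// parameters, so all four combinations are instantiated below.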
#define MAYBE_ADJOINT(ADJ_A, ADJ_B)                                        \
  if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) {                        \
    Status functor_status = functor::SparseTensorDenseMatMulFunctor<       \
        Device, T, Tindices, ADJ_A,                                        \
        ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<T>(),     \
                        a_indices->matrix<Tindices>(), a_values->vec<T>(), \
                        b->matrix<T>());                                   \
    OP_REQUIRES_OK(ctx, functor_status);                                   \
  }

    MAYBE_ADJOINT(false, false);
    MAYBE_ADJOINT(false, true);
    MAYBE_ADJOINT(true, false);
    MAYBE_ADJOINT(true, true);

#undef MAYBE_ADJOINT
  }

 private:
  bool adjoint_a_;
  bool adjoint_b_;
};

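// Compute() reads a_shape's values directly on the host (via
// a_shape->vec<int64>()), so every registration pins a_shape to host memory.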
#define REGISTER_CPU(TypeT, TypeIndex)           \
  REGISTER_KERNEL_BUILDER(                       \
      Name("SparseTensorDenseMatMul")            \
          .Device(DEVICE_CPU)                    \
          .TypeConstraint<TypeT>("T")            \
          .TypeConstraint<TypeIndex>("Tindices") \
          .HostMemory("a_shape"),                \
      SparseTensorDenseMatMulOp<CPUDevice, TypeT, TypeIndex>);

#define REGISTER_KERNELS_CPU(T) \
  REGISTER_CPU(T, int64);       \
  REGISTER_CPU(T, int32)

REGISTER_KERNELS_CPU(float);
REGISTER_KERNELS_CPU(double);
REGISTER_KERNELS_CPU(int32);
REGISTER_KERNELS_CPU(complex64);
REGISTER_KERNELS_CPU(complex128);

#if GOOGLE_CUDA

namespace functor {
#define DECLARE_GPU_SPEC(T, Tindices, ADJ_A, ADJ_B)                       \
  template <>                                                             \
  Status SparseTensorDenseMatMulFunctor<                                  \
      GPUDevice, T, Tindices, ADJ_A,                                      \
      ADJ_B>::Compute(const GPUDevice& d, typename TTypes<T>::Matrix out, \
                      TTypes<Tindices>::ConstMatrix a_indices,            \
                      typename TTypes<T>::ConstVec a_values,              \
                      typename TTypes<T>::ConstMatrix b);                 \
  extern template struct SparseTensorDenseMatMulFunctor<                  \
      GPUDevice, T, Tindices, ADJ_A, ADJ_B>;

#define REGISTER_GPU_SPEC(T, ADJ_A, ADJ_B)  \
  DECLARE_GPU_SPEC(T, int32, ADJ_A, ADJ_B); \
  DECLARE_GPU_SPEC(T, int64, ADJ_A, ADJ_B)

#define DECLARE_ADJOINT_GPU_SPEC(T)  \
  REGISTER_GPU_SPEC(T, false, false) \
  REGISTER_GPU_SPEC(T, false, true)  \
  REGISTER_GPU_SPEC(T, true, false)  \
  REGISTER_GPU_SPEC(T, true, true)

DECLARE_ADJOINT_GPU_SPEC(float);
#undef DECLARE_ADJOINT_GPU_SPEC
#undef DECLARE_GPU_SPEC
#undef REGISTER_GPU_SPEC

}  // namespace functor

#define REGISTER_GPU(TypeT, TypeIndex)           \
  REGISTER_KERNEL_BUILDER(                       \
      Name("SparseTensorDenseMatMul")            \
          .Device(DEVICE_GPU)                    \
          .TypeConstraint<TypeT>("T")            \
          .TypeConstraint<TypeIndex>("Tindices") \
          .HostMemory("a_shape"),                \
      SparseTensorDenseMatMulOp<GPUDevice, TypeT, TypeIndex>);

#define REGISTER_KERNELS_GPU(T) \
  REGISTER_GPU(T, int64);       \
  REGISTER_GPU(T, int32)

REGISTER_KERNELS_GPU(float);
#undef REGISTER_GPU
#undef REGISTER_KERNELS_GPU
#endif  // GOOGLE_CUDA

namespace functor {

namespace {
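// Error helpers for out-of-range sparse indices: k is the contraction index
// and m the output row, each recovered from a column of a_indices.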
Status KOutOfBoundsError(int64 k, std::size_t i, int rhs_index_a,
                         std::size_t lhs_right) {
  return errors::InvalidArgument("k (", k, ") from index[", i, ",", rhs_index_a,
                                 "] out of bounds (>=", lhs_right, ")");
}

Status MOutOfBoundsError(int64 m, std::size_t i, int lhs_index_a,
                         int64 out_dim0) {
  return errors::InvalidArgument("m (", m, ") from index[", i, ",", lhs_index_a,
                                 "] out of bounds (>=", out_dim0, ")");
}
}  // namespace

template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
  // Vectorize certain operations above this size.
  static const std::size_t kNumVectorize = 32;

  static Status Compute(const CPUDevice& d, typename TTypes<T>::Matrix out,
                        typename TTypes<Tindices>::ConstMatrix a_indices,
                        typename TTypes<T>::ConstVec a_values,
                        typename TTypes<T>::ConstMatrix b) {
    const std::size_t nnz = a_values.size();
    const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1));
    const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0));
    const int lhs_index_a = ADJ_A ? 1 : 0;
    const int rhs_index_a = ADJ_A ? 0 : 1;
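    // a_indices(i, lhs_index_a) holds the output row m and
    // a_indices(i, rhs_index_a) the contraction index k; adjointing A swaps
    // the roles of the two index columns.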

    out.setZero();

    // TODO(ebrevdo): After many failed experiments, we can't find a
    // multi-threaded approach that matches the performance of the
    // single-threaded one.  Perhaps the Eigen threadpool implementation is
    // just too slow?

    if (rhs_right < kNumVectorize) {
      // Disable vectorization if the RHS dimension of the output is too small.
      auto maybe_adjoint_b = MaybeAdjoint<decltype(b), ADJ_B>(b);

      for (std::size_t i = 0; i < nnz; ++i) {
        const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
        const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
        if (!FastBoundsCheck(k, lhs_right)) {
          return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);
        }
        if (!FastBoundsCheck(m, out.dimension(0))) {
          return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0));
        }
        const T a_value = ADJ_A ? MaybeConj(a_values(i)) : a_values(i);
        for (std::size_t n = 0; n < rhs_right; ++n) {
          const T b_value = maybe_adjoint_b(k, n);
          out(m, n) += a_value * b_value;
        }
      }
    } else {
      // Vectorization via Eigen.
      const int b_chip_index = ADJ_B ? 1 : 0;

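// For each stored nonzero A(m, k) = a_value, LOOP_NNZ adds
// a_value * (row k of op(B)) into row m of the output: chip<0>(m) selects
// row m of out, and chipping b_passed along b_chip_index selects the
// matching slice of B.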
#define LOOP_NNZ(b_passed)                                                  \
  for (std::size_t i = 0; i < nnz; ++i) {                                   \
    const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); \
    const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); \
    const T a_value = (ADJ_A) ? MaybeConj(a_values(i)) : a_values(i);       \
    if (!FastBoundsCheck(k, lhs_right)) {                                   \
      return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);               \
    }                                                                       \
    if (!FastBoundsCheck(m, out.dimension(0))) {                            \
      return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0));        \
    }                                                                       \
    out.template chip<0>(m) +=                                              \
        b_passed.template chip<b_chip_index>(k) * a_value;                  \
  }

      if (ADJ_B) {
        // Perform transpose and conjugation on B once, since we chip out B's
        // columns in the nnz loop.
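        // swap_layout() flips the storage order (reversing the dimension
        // meaning) and the shuffle restores the logical dimension order, so
        // col_major_conj_b is conj(B) with B's logical shape but column-major
        // storage, which makes its column chips contiguous and cheap.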
        Eigen::array<int, 2> shuffle(1, 0);  // preserve dimension order
        Eigen::Tensor<T, 2, Eigen::ColMajor> col_major_conj_b =
            b.swap_layout().shuffle(shuffle).conjugate();
        LOOP_NNZ(col_major_conj_b);
      } else {
        LOOP_NNZ(b);
      }
#undef LOOP_NNZ
    }
    return Status::OK();
  }
};

}  // namespace functor

}  // namespace tensorflow