1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include "tensorflow/core/kernels/sparse_tensor_dense_add_op.h" 19 20 #include "tensorflow/core/framework/op_kernel.h" 21 #include "tensorflow/core/framework/register_types.h" 22 #include "tensorflow/core/framework/tensor.h" 23 #include "tensorflow/core/framework/tensor_util.h" 24 #include "tensorflow/core/framework/types.h" 25 #include "tensorflow/core/util/sparse/sparse_tensor.h" 26 27 namespace tensorflow { 28 29 typedef Eigen::ThreadPoolDevice CPUDevice; 30 // NOTE: does not support GPU yet. 31 32 namespace { 33 34 template <typename Index> 35 Status ValidateInputs(const Tensor *a_indices, const Tensor *a_values, 36 const Tensor *a_shape, const Tensor *b) { 37 if (!TensorShapeUtils::IsMatrix(a_indices->shape())) { 38 return errors::InvalidArgument( 39 "Input a_indices should be a matrix but received shape: ", 40 a_indices->shape().DebugString()); 41 } 42 if (!TensorShapeUtils::IsVector(a_values->shape()) || 43 !TensorShapeUtils::IsVector(a_shape->shape())) { 44 return errors::InvalidArgument( 45 "Inputs a_values and a_shape should be vectors " 46 "but received shapes: ", 47 a_values->shape().DebugString(), " and ", 48 a_shape->shape().DebugString()); 49 } 50 if (a_shape->NumElements() != b->dims()) { 51 return errors::InvalidArgument( 52 "Two operands have different ranks; received: ", a_shape->NumElements(), 53 " and ", b->dims()); 54 } 55 const auto a_shape_flat = a_shape->flat<Index>(); 56 for (int i = 0; i < b->dims(); ++i) { 57 if (a_shape_flat(i) != b->dim_size(i)) { 58 return errors::InvalidArgument( 59 "Dimension ", i, 60 " does not equal (no broadcasting is supported): sparse side ", 61 a_shape_flat(i), " vs dense side ", b->dim_size(i)); 62 } 63 } 64 return Status::OK(); 65 } 66 67 } // namespace 68 69 template <typename Device, typename T, typename Index> 70 class SparseTensorDenseAddOp : public OpKernel { 71 public: 72 explicit SparseTensorDenseAddOp(OpKernelConstruction *ctx) : OpKernel(ctx) {} 73 74 void Compute(OpKernelContext *ctx) override { 75 const Tensor *a_indices_t, *a_values_t, *a_shape_t, *b; 76 OP_REQUIRES_OK(ctx, ctx->input("a_indices", &a_indices_t)); 77 OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values_t)); 78 OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape_t)); 79 OP_REQUIRES_OK(ctx, ctx->input("b", &b)); 80 OP_REQUIRES_OK( 81 ctx, ValidateInputs<Index>(a_indices_t, a_values_t, a_shape_t, b)); 82 83 Tensor *out_t; 84 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, b->shape(), &out_t)); 85 86 const int ndims = static_cast<int>(a_indices_t->dim_size(1)); 87 const auto a_indices_mat = a_indices_t->flat_inner_dims<Index>(); 88 const auto a_values_flat = a_values_t->flat<T>(); 89 90 switch (ndims) { 91 #define NDIMS_CASE(N) \ 92 case N: { \ 93 auto out_tensor = out_t->tensor<T, N>(); \ 94 out_tensor.device(ctx->eigen_device<Device>()) = b->tensor<T, N>(); \ 95 const Index result = \ 96 functor::ScatterNdFunctor<Device, T, Index, N, \ 97 scatter_op::UpdateOp::ADD>()( \ 98 ctx->eigen_device<Device>(), a_indices_mat, a_values_flat, \ 99 out_tensor); \ 100 OP_REQUIRES( \ 101 ctx, result == -1, \ 102 errors::InvalidArgument( \ 103 "Sparse tensor has some invalid index on dimension ", result, \ 104 "; dense tensor shape: ", b->shape().DebugString())); \ 105 } break; 106 107 NDIMS_CASE(1); 108 NDIMS_CASE(2); 109 NDIMS_CASE(3); 110 NDIMS_CASE(4); 111 NDIMS_CASE(5); 112 default: 113 OP_REQUIRES( 114 ctx, false, 115 errors::InvalidArgument("Only tensors with ranks between 1 and 5 " 116 "are currently supported. Tensor rank: ", 117 ndims)); 118 #undef NDIMS_CASE 119 } 120 } 121 }; 122 123 namespace functor { 124 template <typename T, typename Index, int NDIMS> 125 struct ScatterNdFunctor<CPUDevice, T, Index, NDIMS, scatter_op::UpdateOp::ADD> { 126 Index operator()(const CPUDevice &d, 127 typename TTypes<Index>::ConstMatrix indices, 128 typename TTypes<T>::ConstFlat updates, 129 typename TTypes<T, NDIMS>::Tensor out) { 130 Eigen::array<Eigen::DenseIndex, NDIMS> idx; 131 const int num_nnz = static_cast<int>(indices.dimension(0)); 132 for (int i = 0; i < num_nnz; ++i) { 133 for (int d = 0; d < NDIMS; ++d) { 134 idx[d] = internal::SubtleMustCopy(indices(i, d)); 135 if (!FastBoundsCheck(idx[d], out.dimension(d))) { 136 return d; // on failure: d nonnegative 137 } 138 } 139 out(idx) += updates(i); 140 } 141 return -1; // on success 142 } 143 }; 144 } // namespace functor 145 146 #define REGISTER_KERNELS_CPU(TypeT, TypeIndex) \ 147 REGISTER_KERNEL_BUILDER(Name("SparseTensorDenseAdd") \ 148 .Device(DEVICE_CPU) \ 149 .TypeConstraint<TypeT>("T") \ 150 .TypeConstraint<TypeIndex>("Tindices"), \ 151 SparseTensorDenseAddOp<CPUDevice, TypeT, TypeIndex>) 152 153 #define REGISTER_KERNELS(T) \ 154 REGISTER_KERNELS_CPU(T, int64); \ 155 REGISTER_KERNELS_CPU(T, int32) 156 157 TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); 158 #undef REGISTER_KERNELS 159 #undef REGISTER_KERNELS_CPU 160 } // namespace tensorflow 161