1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/op_kernel.h" 17 #include "tensorflow/core/framework/register_types.h" 18 #include "tensorflow/core/framework/tensor.h" 19 #include "tensorflow/core/framework/tensor_util.h" 20 #include "tensorflow/core/framework/types.h" 21 #include "tensorflow/core/util/sparse/sparse_tensor.h" 22 23 namespace tensorflow { 24 25 template <typename T, typename Treal> 26 class SparseAddOp : public OpKernel { 27 public: 28 explicit SparseAddOp(OpKernelConstruction *ctx) : OpKernel(ctx) {} 29 30 void Compute(OpKernelContext *ctx) override { 31 // (0) validations 32 const Tensor *a_indices, *b_indices, *a_values_t, *b_values_t, *a_shape, 33 *b_shape, *thresh_t; 34 35 OP_REQUIRES_OK(ctx, ctx->input("a_indices", &a_indices)); 36 OP_REQUIRES_OK(ctx, ctx->input("b_indices", &b_indices)); 37 OP_REQUIRES(ctx, 38 TensorShapeUtils::IsMatrix(a_indices->shape()) && 39 TensorShapeUtils::IsMatrix(b_indices->shape()), 40 errors::InvalidArgument( 41 "Input indices should be matrices but received shapes: ", 42 a_indices->shape().DebugString(), " and ", 43 b_indices->shape().DebugString())); 44 const int64 a_nnz = a_indices->dim_size(0); 45 const int64 b_nnz = b_indices->dim_size(0); 46 47 OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values_t)); 48 OP_REQUIRES_OK(ctx, ctx->input("b_values", &b_values_t)); 49 50 OP_REQUIRES(ctx, 51 TensorShapeUtils::IsVector(a_values_t->shape()) && 52 TensorShapeUtils::IsVector(b_values_t->shape()), 53 errors::InvalidArgument( 54 "Input values should be vectors but received shapes: ", 55 a_values_t->shape().DebugString(), " and ", 56 b_values_t->shape().DebugString())); 57 auto a_values = ctx->input(1).vec<T>(); 58 auto b_values = ctx->input(4).vec<T>(); 59 OP_REQUIRES( 60 ctx, a_values.size() == a_nnz && b_values.size() == b_nnz, 61 errors::InvalidArgument("Expected ", a_nnz, " and ", b_nnz, 62 " non-empty input values, got ", 63 a_values.size(), " and ", b_values.size())); 64 65 OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape)); 66 OP_REQUIRES_OK(ctx, ctx->input("b_shape", &b_shape)); 67 OP_REQUIRES(ctx, 68 TensorShapeUtils::IsVector(a_shape->shape()) && 69 TensorShapeUtils::IsVector(b_shape->shape()), 70 errors::InvalidArgument( 71 "Input shapes should be a vector but received shapes ", 72 a_shape->shape().DebugString(), " and ", 73 b_shape->shape().DebugString())); 74 OP_REQUIRES( 75 ctx, a_shape->IsSameSize(*b_shape), 76 errors::InvalidArgument( 77 "Operands do not have the same ranks; got shapes: ", 78 a_shape->SummarizeValue(10), " and ", b_shape->SummarizeValue(10))); 79 const auto a_shape_flat = a_shape->flat<int64>(); 80 const auto b_shape_flat = b_shape->flat<int64>(); 81 for (int i = 0; i < a_shape->NumElements(); ++i) { 82 OP_REQUIRES(ctx, a_shape_flat(i) == b_shape_flat(i), 83 errors::InvalidArgument( 84 "Operands' shapes do not match: got ", a_shape_flat(i), 85 " and ", b_shape_flat(i), " for dimension ", i)); 86 } 87 88 OP_REQUIRES_OK(ctx, ctx->input("thresh", &thresh_t)); 89 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(thresh_t->shape()), 90 errors::InvalidArgument( 91 "The magnitude threshold must be a scalar: got shape ", 92 thresh_t->shape().DebugString())); 93 // std::abs() so that it works for complex{64,128} values as well 94 const Treal thresh = thresh_t->scalar<Treal>()(); 95 96 // (1) do a pass over inputs, and append values and indices to vectors 97 auto a_indices_mat = a_indices->matrix<int64>(); 98 auto b_indices_mat = b_indices->matrix<int64>(); 99 std::vector<std::pair<bool, int64>> entries_to_copy; // from_a?, idx 100 entries_to_copy.reserve(a_nnz + b_nnz); 101 std::vector<T> out_values; 102 const int num_dims = a_shape->dim_size(0); 103 104 // The input and output sparse tensors are assumed to be ordered along 105 // increasing dimension number. 106 int64 i = 0, j = 0; 107 T s; 108 while (i < a_nnz && j < b_nnz) { 109 switch (sparse::DimComparator::cmp(a_indices_mat, b_indices_mat, i, j, 110 num_dims)) { 111 case -1: 112 entries_to_copy.emplace_back(true, i); 113 out_values.push_back(a_values(i)); 114 ++i; 115 break; 116 case 0: 117 s = a_values(i) + b_values(j); 118 if (thresh <= std::abs(s)) { 119 entries_to_copy.emplace_back(true, i); 120 out_values.push_back(s); 121 } 122 ++i; 123 ++j; 124 break; 125 case 1: 126 entries_to_copy.emplace_back(false, j); 127 out_values.push_back(b_values(j)); 128 ++j; 129 break; 130 } 131 } 132 133 #define HANDLE_LEFTOVERS(A_OR_B, IDX, IS_A) \ 134 while (IDX < A_OR_B##_nnz) { \ 135 entries_to_copy.emplace_back(IS_A, IDX); \ 136 out_values.push_back(A_OR_B##_values(IDX)); \ 137 ++IDX; \ 138 } 139 140 // at most one of these calls appends new values 141 HANDLE_LEFTOVERS(a, i, true); 142 HANDLE_LEFTOVERS(b, j, false); 143 #undef HANDLE_LEFTOVERS 144 145 // (2) allocate and fill output tensors 146 const int64 sum_nnz = out_values.size(); 147 Tensor *out_indices_t, *out_values_t; 148 OP_REQUIRES_OK(ctx, 149 ctx->allocate_output(0, TensorShape({sum_nnz, num_dims}), 150 &out_indices_t)); 151 OP_REQUIRES_OK( 152 ctx, ctx->allocate_output(1, TensorShape({sum_nnz}), &out_values_t)); 153 auto out_indices_mat = out_indices_t->matrix<int64>(); 154 auto out_values_flat = out_values_t->vec<T>(); 155 156 for (i = 0; i < sum_nnz; ++i) { 157 const bool from_a = entries_to_copy[i].first; 158 const int64 idx = entries_to_copy[i].second; 159 out_indices_mat.chip<0>(i) = 160 from_a ? a_indices_mat.chip<0>(idx) : b_indices_mat.chip<0>(idx); 161 } 162 std::copy_n(out_values.begin(), sum_nnz, &out_values_flat(0)); 163 ctx->set_output(2, *a_shape); 164 } 165 }; 166 167 #define REGISTER_KERNELS(type, thresh_type) \ 168 REGISTER_KERNEL_BUILDER( \ 169 Name("SparseAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 170 SparseAddOp<type, thresh_type>) 171 172 // The list below is equivalent to TF_CALL_REAL_NUMBER_TYPES, minus uint8. This 173 // is because std::abs() on uint8 does not compile. 174 REGISTER_KERNELS(float, float); 175 REGISTER_KERNELS(double, double); 176 REGISTER_KERNELS(int64, int64); 177 REGISTER_KERNELS(int32, int32); 178 REGISTER_KERNELS(int16, int16); 179 REGISTER_KERNELS(int8, int8); 180 REGISTER_KERNELS(complex64, float); 181 REGISTER_KERNELS(complex128, double); 182 #undef REGISTER_KERNELS 183 } // namespace tensorflow 184