/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

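// Computes the element-wise sum of two SparseTensors, each given as an
// (indices, values, shape) triple, producing the sum's indices, values, and
// shape. Entries that appear in both operands are summed, and a summed entry
// is dropped from the output if its magnitude falls below the scalar
// "thresh" input.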
template <typename T, typename Treal>
class SparseAddOp : public OpKernel {
 public:
  explicit SparseAddOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext *ctx) override {
    // (0) validations
    const Tensor *a_indices, *b_indices, *a_values_t, *b_values_t, *a_shape,
        *b_shape, *thresh_t;

    OP_REQUIRES_OK(ctx, ctx->input("a_indices", &a_indices));
    OP_REQUIRES_OK(ctx, ctx->input("b_indices", &b_indices));
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsMatrix(a_indices->shape()) &&
                    TensorShapeUtils::IsMatrix(b_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be matrices but received shapes: ",
                    a_indices->shape().DebugString(), " and ",
                    b_indices->shape().DebugString()));
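    // The number of non-empty entries in each operand equals the number of
    // rows in its indices matrix.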
    const int64 a_nnz = a_indices->dim_size(0);
    const int64 b_nnz = b_indices->dim_size(0);

    OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values_t));
    OP_REQUIRES_OK(ctx, ctx->input("b_values", &b_values_t));

    OP_REQUIRES(ctx,
                TensorShapeUtils::IsVector(a_values_t->shape()) &&
                    TensorShapeUtils::IsVector(b_values_t->shape()),
                errors::InvalidArgument(
                    "Input values should be vectors but received shapes: ",
                    a_values_t->shape().DebugString(), " and ",
                    b_values_t->shape().DebugString()));
    auto a_values = a_values_t->vec<T>();
    auto b_values = b_values_t->vec<T>();
    OP_REQUIRES(
        ctx, a_values.size() == a_nnz && b_values.size() == b_nnz,
        errors::InvalidArgument("Expected ", a_nnz, " and ", b_nnz,
                                " non-empty input values, got ",
                                a_values.size(), " and ", b_values.size()));

    OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape));
    OP_REQUIRES_OK(ctx, ctx->input("b_shape", &b_shape));
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsVector(a_shape->shape()) &&
                    TensorShapeUtils::IsVector(b_shape->shape()),
                errors::InvalidArgument(
                    "Input shapes should be vectors but received shapes: ",
                    a_shape->shape().DebugString(), " and ",
                    b_shape->shape().DebugString()));
    OP_REQUIRES(
        ctx, a_shape->IsSameSize(*b_shape),
        errors::InvalidArgument(
            "Operands do not have the same ranks; got shapes: ",
            a_shape->SummarizeValue(10), " and ", b_shape->SummarizeValue(10)));
    const auto a_shape_flat = a_shape->flat<int64>();
    const auto b_shape_flat = b_shape->flat<int64>();
    for (int i = 0; i < a_shape->NumElements(); ++i) {
      OP_REQUIRES(ctx, a_shape_flat(i) == b_shape_flat(i),
                  errors::InvalidArgument(
                      "Operands' shapes do not match: got ", a_shape_flat(i),
                      " and ", b_shape_flat(i), " for dimension ", i));
    }

    OP_REQUIRES_OK(ctx, ctx->input("thresh", &thresh_t));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(thresh_t->shape()),
                errors::InvalidArgument(
                    "The magnitude threshold must be a scalar: got shape ",
                    thresh_t->shape().DebugString()));
    // The threshold is read as Treal (the real counterpart of T) because it
    // is compared against std::abs(), which is real-valued for the
    // complex{64,128} types as well.
    const Treal thresh = thresh_t->scalar<Treal>()();

    // (1) do a pass over inputs, and append values and indices to vectors
    auto a_indices_mat = a_indices->matrix<int64>();
    auto b_indices_mat = b_indices->matrix<int64>();
    std::vector<std::pair<bool, int64>> entries_to_copy;  // from_a?, idx
    entries_to_copy.reserve(a_nnz + b_nnz);
    std::vector<T> out_values;
    const int num_dims = a_shape->dim_size(0);

    // The input and output sparse tensors are assumed to be ordered along
    // increasing dimension number.
    int64 i = 0, j = 0;
    T s;
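    // Two-pointer merge over the lexicographically sorted index rows:
    //   -1: a's current entry precedes b's and is copied from a;
    //    0: the indices match, so the two values are summed;
    //    1: b's current entry precedes a's and is copied from b.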
    while (i < a_nnz && j < b_nnz) {
      switch (sparse::DimComparator::cmp(a_indices_mat, b_indices_mat, i, j,
                                         num_dims)) {
        case -1:
          entries_to_copy.emplace_back(true, i);
          out_values.push_back(a_values(i));
          ++i;
          break;
        case 0:
          s = a_values(i) + b_values(j);
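          // Only overlapping entries are thresholded: the summed value is
          // kept iff its magnitude is at least thresh (std::abs() also
          // handles the complex types).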
          if (thresh <= std::abs(s)) {
            entries_to_copy.emplace_back(true, i);
            out_values.push_back(s);
          }
          ++i;
          ++j;
          break;
        case 1:
          entries_to_copy.emplace_back(false, j);
          out_values.push_back(b_values(j));
          ++j;
          break;
      }
    }

#define HANDLE_LEFTOVERS(A_OR_B, IDX, IS_A)     \
  while (IDX < A_OR_B##_nnz) {                  \
    entries_to_copy.emplace_back(IS_A, IDX);    \
    out_values.push_back(A_OR_B##_values(IDX)); \
    ++IDX;                                      \
  }

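    // Whichever operand was not exhausted above has its remaining entries
    // copied verbatim; they are never thresholded since no addition occurs.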
    // at most one of these calls appends new values
    HANDLE_LEFTOVERS(a, i, true);
    HANDLE_LEFTOVERS(b, j, false);
#undef HANDLE_LEFTOVERS

    // (2) allocate and fill output tensors
    const int64 sum_nnz = out_values.size();
    Tensor *out_indices_t, *out_values_t;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({sum_nnz, num_dims}),
                                        &out_indices_t));
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(1, TensorShape({sum_nnz}), &out_values_t));
    auto out_indices_mat = out_indices_t->matrix<int64>();
    auto out_values_flat = out_values_t->vec<T>();

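    // Gather the output index rows: each row is copied from whichever
    // operand the entry originated in, as recorded in entries_to_copy.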
    for (i = 0; i < sum_nnz; ++i) {
      const bool from_a = entries_to_copy[i].first;
      const int64 idx = entries_to_copy[i].second;
      out_indices_mat.chip<0>(i) =
          from_a ? a_indices_mat.chip<0>(idx) : b_indices_mat.chip<0>(idx);
    }
    std::copy_n(out_values.begin(), sum_nnz, &out_values_flat(0));
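    // The dense shapes were validated to be identical, so the output shape
    // simply forwards the first operand's shape tensor.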
    ctx->set_output(2, *a_shape);
  }
};

#define REGISTER_KERNELS(type, thresh_type)                           \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("SparseAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseAddOp<type, thresh_type>)

// The registrations below cover TF_CALL_REAL_NUMBER_TYPES minus uint8 (since
// std::abs() on uint8 does not compile), plus the complex types, which use
// the corresponding real type for the magnitude threshold.
REGISTER_KERNELS(float, float);
REGISTER_KERNELS(double, double);
REGISTER_KERNELS(int64, int64);
REGISTER_KERNELS(int32, int32);
REGISTER_KERNELS(int16, int16);
REGISTER_KERNELS(int8, int8);
REGISTER_KERNELS(complex64, float);
REGISTER_KERNELS(complex128, double);
#undef REGISTER_KERNELS
}  // namespace tensorflow