// tensorflow/core/kernels: SparseConcat CPU kernel (web-viewer navigation line removed).
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include <algorithm>
     19 #include <numeric>
     20 #include <unordered_map>
     21 #include <utility>
     22 #include <vector>
     23 
     24 #include "tensorflow/core/framework/op_kernel.h"
     25 #include "tensorflow/core/framework/register_types.h"
     26 #include "tensorflow/core/framework/tensor.h"
     27 #include "tensorflow/core/framework/tensor_util.h"
     28 #include "tensorflow/core/framework/types.h"
     29 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     30 #include "tensorflow/core/util/sparse/sparse_tensor.h"
     31 
     32 namespace tensorflow {
     33 
     34 template <typename T>
     35 class SparseConcatOp : public OpKernel {
     36  public:
     37   explicit SparseConcatOp(OpKernelConstruction* context) : OpKernel(context) {
     38     OP_REQUIRES_OK(context, context->GetAttr("concat_dim", &concat_dim_attr_));
     39   }
     40 
     41   void Compute(OpKernelContext* context) override {
     42     OpInputList inds;
     43     OP_REQUIRES_OK(context, context->input_list("indices", &inds));
     44     const int N = inds.size();
     45     for (int i = 0; i < N; i++) {
     46       OP_REQUIRES(context, TensorShapeUtils::IsMatrix(inds[i].shape()),
     47                   errors::InvalidArgument(
     48                       "Input indices should be a matrix but received shape ",
     49                       inds[i].shape().DebugString(), " at position ", i));
     50     }
     51 
     52     OpInputList vals;
     53     OP_REQUIRES_OK(context, context->input_list("values", &vals));
     54     OP_REQUIRES(context, vals.size() == N,
     55                 errors::InvalidArgument("Expected ", N, " input values, got ",
     56                                         vals.size()));
     57     for (int i = 0; i < N; i++) {
     58       OP_REQUIRES(context, TensorShapeUtils::IsVector(vals[i].shape()),
     59                   errors::InvalidArgument(
     60                       "Input values should be a vector but received shape ",
     61                       vals[i].shape().DebugString(), " at position ", i));
     62     }
     63 
     64     OpInputList shapes;
     65     OP_REQUIRES_OK(context, context->input_list("shapes", &shapes));
     66     OP_REQUIRES(context, shapes.size() == N,
     67                 errors::InvalidArgument("Expected ", N, " input shapes, got ",
     68                                         shapes.size()));
     69     for (int i = 0; i < N; i++) {
     70       OP_REQUIRES(context, TensorShapeUtils::IsVector(shapes[i].shape()),
     71                   errors::InvalidArgument(
     72                       "Input shapes should be a vector but received shape ",
     73                       shapes[i].shape().DebugString(), " at position ", i));
     74     }
     75 
     76     const TensorShape input_shape(shapes[0].vec<int64>());
     77     const int input_rank = input_shape.dims();
     78     const int concat_dim = (concat_dim_attr_ < 0)
     79                                ? input_rank + concat_dim_attr_
     80                                : concat_dim_attr_;
     81     OP_REQUIRES(context, concat_dim >= 0 && concat_dim < input_rank,
     82                 errors::InvalidArgument("Concat dimension must be in range [",
     83                                         -input_rank, ", ", input_rank,
     84                                         "), got ", concat_dim_attr_));
     85     for (int i = 1; i < N; ++i) {
     86       const TensorShape current_shape(shapes[i].vec<int64>());
     87       OP_REQUIRES(
     88           context, current_shape.dims() == input_rank,
     89           errors::InvalidArgument(
     90               "Ranks of all input tensors must match: expected ", input_rank,
     91               " but got ", current_shape.dims(), " at position ", i));
     92       for (int j = 0; j < input_rank; ++j) {
     93         if (j != concat_dim) {
     94           OP_REQUIRES(
     95               context, input_shape.dim_size(j) == current_shape.dim_size(j),
     96               errors::InvalidArgument(
     97                   "Input shapes must match: expected ", input_shape.dim_size(j),
     98                   " for dimension ", j, " but got ", current_shape.dim_size(j),
     99                   " at position ", i));
    100         }
    101       }
    102     }
    103 
    104     // The input and output sparse tensors are assumed to be ordered along
    105     // increasing dimension number. But in order for concat to work properly,
    106     // order[0] must be concat_dim. So we will reorder the inputs to the
    107     // concat ordering, concatenate, then reorder back to the standard order.
    108     // We make a deep copy of the input tensors to ensure that the in-place
    109     // reorder doesn't create race conditions for other ops that may be
    110     // concurrently reading the indices and values tensors.
    111 
    112     gtl::InlinedVector<int64, 8> std_order(input_rank);
    113     std::iota(std_order.begin(), std_order.end(), 0);
    114 
    115     std::vector<int64> concat_order;
    116     concat_order.reserve(input_rank);
    117     concat_order.push_back(concat_dim);
    118     for (int j = 0; j < input_rank; ++j) {
    119       if (j != concat_dim) {
    120         concat_order.push_back(j);
    121       }
    122     }
    123 
    124     std::vector<sparse::SparseTensor> sp_inputs;
    125     for (int i = 0; i < N; ++i) {
    126       const TensorShape current_shape(shapes[i].vec<int64>());
    127       sp_inputs.emplace_back(tensor::DeepCopy(inds[i]),
    128                              tensor::DeepCopy(vals[i]), current_shape,
    129                              std_order);
    130       sp_inputs[i].Reorder<T>(concat_order);
    131     }
    132 
    133     sparse::SparseTensor concat = sparse::SparseTensor::Concat<T>(sp_inputs);
    134     concat.Reorder<T>(std_order);
    135 
    136     context->set_output(0, concat.indices());
    137     context->set_output(1, concat.values());
    138 
    139     Tensor* output_shape_out = nullptr;
    140     OP_REQUIRES_OK(context,
    141                    context->allocate_output(2, TensorShape({concat.dims()}),
    142                                             &output_shape_out));
    143     auto output_shape = output_shape_out->vec<int64>();
    144     auto concat_shape = concat.shape();
    145     for (int j = 0; j < concat.dims(); ++j) {
    146       output_shape(j) = concat_shape[j];
    147     }
    148   }
    149 
    150  private:
    151   int concat_dim_attr_;
    152 };
    153 
// Registers the CPU implementation of the "SparseConcat" op for a single
// values dtype `type`, binding it to SparseConcatOp<type>.
#define REGISTER_KERNELS(type)                                           \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("SparseConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseConcatOp<type>)

// Instantiate the registration for every TensorFlow-supported element type.
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
}  // namespace tensorflow
    162