1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include <algorithm> 19 #include <numeric> 20 #include <unordered_map> 21 #include <utility> 22 #include <vector> 23 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/register_types.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor_util.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/lib/gtl/inlined_vector.h" 30 #include "tensorflow/core/util/sparse/sparse_tensor.h" 31 32 namespace tensorflow { 33 34 template <typename T> 35 class SparseConcatOp : public OpKernel { 36 public: 37 explicit SparseConcatOp(OpKernelConstruction* context) : OpKernel(context) { 38 OP_REQUIRES_OK(context, context->GetAttr("concat_dim", &concat_dim_attr_)); 39 } 40 41 void Compute(OpKernelContext* context) override { 42 OpInputList inds; 43 OP_REQUIRES_OK(context, context->input_list("indices", &inds)); 44 const int N = inds.size(); 45 for (int i = 0; i < N; i++) { 46 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(inds[i].shape()), 47 errors::InvalidArgument( 48 "Input indices should be a matrix but received shape ", 49 inds[i].shape().DebugString(), " at position ", i)); 50 } 51 52 OpInputList vals; 53 OP_REQUIRES_OK(context, context->input_list("values", &vals)); 54 OP_REQUIRES(context, vals.size() == N, 55 errors::InvalidArgument("Expected ", N, " input values, got ", 56 vals.size())); 57 for (int i = 0; i < N; i++) { 58 OP_REQUIRES(context, TensorShapeUtils::IsVector(vals[i].shape()), 59 errors::InvalidArgument( 60 "Input values should be a vector but received shape ", 61 vals[i].shape().DebugString(), " at position ", i)); 62 } 63 64 OpInputList shapes; 65 OP_REQUIRES_OK(context, context->input_list("shapes", &shapes)); 66 OP_REQUIRES(context, shapes.size() == N, 67 errors::InvalidArgument("Expected ", N, " input shapes, got ", 68 shapes.size())); 69 for (int i = 0; i < N; i++) { 70 OP_REQUIRES(context, TensorShapeUtils::IsVector(shapes[i].shape()), 71 errors::InvalidArgument( 72 "Input shapes should be a vector but received shape ", 73 shapes[i].shape().DebugString(), " at position ", i)); 74 } 75 76 const TensorShape input_shape(shapes[0].vec<int64>()); 77 const int input_rank = input_shape.dims(); 78 const int concat_dim = (concat_dim_attr_ < 0) 79 ? input_rank + concat_dim_attr_ 80 : concat_dim_attr_; 81 OP_REQUIRES(context, concat_dim >= 0 && concat_dim < input_rank, 82 errors::InvalidArgument("Concat dimension must be in range [", 83 -input_rank, ", ", input_rank, 84 "), got ", concat_dim_attr_)); 85 for (int i = 1; i < N; ++i) { 86 const TensorShape current_shape(shapes[i].vec<int64>()); 87 OP_REQUIRES( 88 context, current_shape.dims() == input_rank, 89 errors::InvalidArgument( 90 "Ranks of all input tensors must match: expected ", input_rank, 91 " but got ", current_shape.dims(), " at position ", i)); 92 for (int j = 0; j < input_rank; ++j) { 93 if (j != concat_dim) { 94 OP_REQUIRES( 95 context, input_shape.dim_size(j) == current_shape.dim_size(j), 96 errors::InvalidArgument( 97 "Input shapes must match: expected ", input_shape.dim_size(j), 98 " for dimension ", j, " but got ", current_shape.dim_size(j), 99 " at position ", i)); 100 } 101 } 102 } 103 104 // The input and output sparse tensors are assumed to be ordered along 105 // increasing dimension number. But in order for concat to work properly, 106 // order[0] must be concat_dim. So we will reorder the inputs to the 107 // concat ordering, concatenate, then reorder back to the standard order. 108 // We make a deep copy of the input tensors to ensure that the in-place 109 // reorder doesn't create race conditions for other ops that may be 110 // concurrently reading the indices and values tensors. 111 112 gtl::InlinedVector<int64, 8> std_order(input_rank); 113 std::iota(std_order.begin(), std_order.end(), 0); 114 115 std::vector<int64> concat_order; 116 concat_order.reserve(input_rank); 117 concat_order.push_back(concat_dim); 118 for (int j = 0; j < input_rank; ++j) { 119 if (j != concat_dim) { 120 concat_order.push_back(j); 121 } 122 } 123 124 std::vector<sparse::SparseTensor> sp_inputs; 125 for (int i = 0; i < N; ++i) { 126 const TensorShape current_shape(shapes[i].vec<int64>()); 127 sp_inputs.emplace_back(tensor::DeepCopy(inds[i]), 128 tensor::DeepCopy(vals[i]), current_shape, 129 std_order); 130 sp_inputs[i].Reorder<T>(concat_order); 131 } 132 133 sparse::SparseTensor concat = sparse::SparseTensor::Concat<T>(sp_inputs); 134 concat.Reorder<T>(std_order); 135 136 context->set_output(0, concat.indices()); 137 context->set_output(1, concat.values()); 138 139 Tensor* output_shape_out = nullptr; 140 OP_REQUIRES_OK(context, 141 context->allocate_output(2, TensorShape({concat.dims()}), 142 &output_shape_out)); 143 auto output_shape = output_shape_out->vec<int64>(); 144 auto concat_shape = concat.shape(); 145 for (int j = 0; j < concat.dims(); ++j) { 146 output_shape(j) = concat_shape[j]; 147 } 148 } 149 150 private: 151 int concat_dim_attr_; 152 }; 153 154 #define REGISTER_KERNELS(type) \ 155 REGISTER_KERNEL_BUILDER( \ 156 Name("SparseConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 157 SparseConcatOp<type>) 158 159 TF_CALL_ALL_TYPES(REGISTER_KERNELS); 160 #undef REGISTER_KERNELS 161 } // namespace tensorflow 162