1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See core/ops/sparse_ops.cc for documentation. 17 // 18 // NOTE: the operations in this file only are suitable for execution 19 // on CPUs. 20 21 #define EIGEN_USE_THREADS 22 23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 24 25 #include <numeric> 26 #include <sstream> 27 #include <string> 28 #include <unordered_map> 29 #include <utility> 30 31 #include "tensorflow/core/framework/op_kernel.h" 32 #include "tensorflow/core/framework/register_types.h" 33 #include "tensorflow/core/framework/tensor.h" 34 #include "tensorflow/core/framework/types.h" 35 #include "tensorflow/core/lib/core/status.h" 36 #include "tensorflow/core/lib/gtl/inlined_vector.h" 37 #include "tensorflow/core/lib/strings/stringprintf.h" 38 #include "tensorflow/core/util/sparse/sparse_tensor.h" 39 40 namespace tensorflow { 41 42 // Operator to convert sparse representations to dense. 43 template <typename T, typename Index> 44 class SparseToDense : public OpKernel { 45 public: 46 explicit SparseToDense(OpKernelConstruction* context) : OpKernel(context) { 47 OP_REQUIRES_OK(context, 48 context->GetAttr("validate_indices", &validate_indices_)); 49 } 50 51 void Compute(OpKernelContext* c) override { 52 // sparse_indices 53 const Tensor& indices = c->input(0); 54 OP_REQUIRES(c, indices.dims() <= 2, 55 errors::InvalidArgument( 56 "sparse_indices should be a scalar, vector, or matrix, " 57 "got shape ", 58 indices.shape().DebugString())); 59 const int64 num_elems = indices.dims() > 0 ? indices.dim_size(0) : 1; 60 const int64 num_dims = indices.dims() > 1 ? indices.dim_size(1) : 1; 61 62 // output_shape 63 const Tensor& output_shape = c->input(1); 64 OP_REQUIRES( 65 c, IsLegacyVector(output_shape.shape()), 66 errors::InvalidArgument("output_shape should be a vector, got shape ", 67 output_shape.shape().DebugString())); 68 OP_REQUIRES(c, output_shape.NumElements() == num_dims, 69 errors::InvalidArgument( 70 "output_shape has incorrect number of elements: ", 71 output_shape.NumElements(), " should be: ", num_dims)); 72 73 // sparse_values 74 const Tensor& sparse_values = c->input(2); 75 const int64 num_values = sparse_values.NumElements(); 76 OP_REQUIRES(c, 77 sparse_values.dims() == 0 || 78 (sparse_values.dims() == 1 && num_values == num_elems), 79 errors::InvalidArgument("sparse_values has incorrect shape ", 80 sparse_values.shape().DebugString(), 81 ", should be [] or [", num_elems, "]")); 82 83 // default_value 84 const Tensor& default_value = c->input(3); 85 OP_REQUIRES(c, TensorShapeUtils::IsScalar(default_value.shape()), 86 errors::InvalidArgument("default_value should be a scalar.")); 87 88 auto output_shape_vec = output_shape.flat<Index>(); 89 TensorShape output_tensor_shape; 90 OP_REQUIRES_OK(c, TensorShapeUtils::MakeShape(output_shape_vec.data(), 91 output_shape_vec.size(), 92 &output_tensor_shape)); 93 Tensor* output = nullptr; 94 OP_REQUIRES_OK(c, c->allocate_output(0, output_tensor_shape, &output)); 95 96 TensorShape ix_shape({num_elems, num_dims}); 97 Tensor indices_shaped(DT_INT64, ix_shape); 98 if (indices.dtype() == DT_INT64) { 99 CHECK(indices_shaped.CopyFrom(indices, ix_shape)); 100 } else { 101 indices_shaped.matrix<int64>() = 102 indices.shaped<Index, 2>(ix_shape.dim_sizes()).template cast<int64>(); 103 } 104 105 // If we received a scalar, we'll need to create a new 106 // tensor with copies of the values as a vec. 107 // TODO(ebrevdo): find a way to avoid this temp allocation. 108 Tensor sparse_values_b; 109 110 if (TensorShapeUtils::IsScalar(sparse_values.shape())) { 111 OP_REQUIRES_OK( 112 c, c->allocate_temp(DataTypeToEnum<T>::value, 113 TensorShape({num_elems}), &sparse_values_b)); 114 sparse_values_b.vec<T>().setConstant(sparse_values.scalar<T>()()); 115 } else { 116 sparse_values_b = sparse_values; 117 } 118 119 // Assume SparseTensor is lexicographically sorted. 120 gtl::InlinedVector<int64, 8> order(output->shape().dims()); 121 std::iota(order.begin(), order.end(), 0); 122 sparse::SparseTensor st(indices_shaped, sparse_values_b, output->shape(), 123 order); 124 125 if (validate_indices_) { 126 OP_REQUIRES_OK(c, st.IndicesValid()); 127 } 128 129 output->flat<T>().setConstant(default_value.scalar<T>()()); 130 OP_REQUIRES(c, st.template ToDense<T>(output, false /* initialize */), 131 errors::InvalidArgument( 132 "Indices are not valid (out of bounds). Shape: ", 133 output->shape().DebugString())); 134 } 135 136 private: 137 bool validate_indices_; 138 }; 139 140 #define REGISTER_KERNELS(type, index_type) \ 141 REGISTER_KERNEL_BUILDER(Name("SparseToDense") \ 142 .Device(DEVICE_CPU) \ 143 .TypeConstraint<type>("T") \ 144 .TypeConstraint<index_type>("Tindices"), \ 145 SparseToDense<type, index_type>); 146 147 #define REGISTER_KERNELS_ALL(type) \ 148 REGISTER_KERNELS(type, int32); \ 149 REGISTER_KERNELS(type, int64); 150 151 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS_ALL); 152 REGISTER_KERNELS_ALL(bool); 153 REGISTER_KERNELS_ALL(string); 154 155 #undef REGISTER_KERNELS_ALL 156 #undef REGISTER_KERNELS 157 158 } // namespace tensorflow 159