1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include <algorithm> 19 #include <numeric> 20 #include <unordered_map> 21 #include <utility> 22 #include <vector> 23 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/register_types.h" 26 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/framework/register_types.h" 29 #include "tensorflow/core/framework/resource_mgr.h" 30 #include "tensorflow/core/framework/tensor.h" 31 #include "tensorflow/core/framework/tensor_util.h" 32 #include "tensorflow/core/framework/types.h" 33 #include "tensorflow/core/lib/gtl/inlined_vector.h" 34 #include "tensorflow/core/util/sparse/sparse_tensor.h" 35 36 namespace tensorflow { 37 38 typedef Eigen::ThreadPoolDevice CPUDevice; 39 40 using sparse::SparseTensor; 41 42 class SparseTensorsMap : public ResourceBase { 43 public: 44 explicit SparseTensorsMap(const string& name) : name_(name), counter_(0) {} 45 46 string DebugString() override { return "A SparseTensorsMap"; } 47 48 typedef struct { 49 PersistentTensor indices; 50 PersistentTensor values; 51 gtl::InlinedVector<int64, 8> shape; 52 } PersistentSparseTensor; 53 54 Status AddSparseTensor(OpKernelContext* ctx, const SparseTensor& sp, 55 int64* handle) { 56 PersistentTensor persistent_ix; 57 Tensor* ix; 58 TF_RETURN_IF_ERROR(ctx->allocate_persistent( 59 sp.indices().dtype(), sp.indices().shape(), &persistent_ix, &ix)); 60 *ix = sp.indices(); 61 62 PersistentTensor persistent_values; 63 Tensor* values; 64 TF_RETURN_IF_ERROR(ctx->allocate_persistent(sp.indices().dtype(), 65 sp.indices().shape(), 66 &persistent_values, &values)); 67 *values = sp.values(); 68 { 69 mutex_lock l(mu_); 70 int64 unique_st_handle = counter_++; // increment is guarded on purpose 71 sp_tensors_[unique_st_handle] = PersistentSparseTensor{ 72 persistent_ix, persistent_values, 73 gtl::InlinedVector<int64, 8>(sp.shape().begin(), sp.shape().end())}; 74 *handle = unique_st_handle; 75 } 76 return Status::OK(); 77 } 78 79 Status RetrieveAndClearSparseTensors( 80 OpKernelContext* ctx, const TTypes<int64>::ConstVec& handles, 81 std::vector<SparseTensor>* sparse_tensors) { 82 sparse_tensors->clear(); 83 sparse_tensors->reserve(handles.size()); 84 { 85 mutex_lock l(mu_); 86 for (size_t i = 0; i < handles.size(); ++i) { 87 const int64 handle = handles(i); 88 auto sp_iter = sp_tensors_.find(handle); 89 if (sp_iter == sp_tensors_.end()) { 90 return errors::InvalidArgument( 91 "Unable to find SparseTensor: ", handle, " in map: ", name_); 92 } 93 const Tensor* ix = sp_iter->second.indices.AccessTensor(ctx); 94 const Tensor* values = sp_iter->second.values.AccessTensor(ctx); 95 const auto& shape = sp_iter->second.shape; 96 sparse_tensors->emplace_back(*ix, *values, shape); 97 98 sp_tensors_.erase(sp_iter); 99 } 100 } 101 102 return Status::OK(); 103 } 104 105 protected: 106 ~SparseTensorsMap() override {} 107 108 private: 109 string name_; 110 111 mutex mu_; 112 int64 counter_ GUARDED_BY(mu_); 113 std::unordered_map<int64, PersistentSparseTensor> sp_tensors_ GUARDED_BY(mu_); 114 }; 115 116 class SparseTensorAccessingOp : public OpKernel { 117 public: 118 typedef std::function<Status(SparseTensorsMap**)> CreatorCallback; 119 120 explicit SparseTensorAccessingOp(OpKernelConstruction* context) 121 : OpKernel(context), sparse_tensors_map_(nullptr) {} 122 123 protected: 124 ~SparseTensorAccessingOp() override { 125 if (sparse_tensors_map_) sparse_tensors_map_->Unref(); 126 } 127 128 Status GetMap(OpKernelContext* ctx, bool is_writing, 129 SparseTensorsMap** sparse_tensors_map) { 130 mutex_lock l(mu_); 131 132 if (sparse_tensors_map_) { 133 *sparse_tensors_map = sparse_tensors_map_; 134 return Status::OK(); 135 } 136 137 TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def(), 138 is_writing /* use_node_name_as_default */)); 139 140 CreatorCallback sparse_tensors_map_creator = [this](SparseTensorsMap** c) { 141 SparseTensorsMap* map = new SparseTensorsMap(cinfo_.name()); 142 *c = map; 143 return Status::OK(); 144 }; 145 146 TF_RETURN_IF_ERROR( 147 cinfo_.resource_manager()->LookupOrCreate<SparseTensorsMap>( 148 cinfo_.container(), cinfo_.name(), &sparse_tensors_map_, 149 sparse_tensors_map_creator)); 150 151 *sparse_tensors_map = sparse_tensors_map_; 152 return Status::OK(); 153 } 154 155 private: 156 ContainerInfo cinfo_; 157 158 mutex mu_; 159 SparseTensorsMap* sparse_tensors_map_ PT_GUARDED_BY(mu_); 160 }; 161 162 class AddSparseToTensorsMapOp : public SparseTensorAccessingOp { 163 public: 164 explicit AddSparseToTensorsMapOp(OpKernelConstruction* context) 165 : SparseTensorAccessingOp(context) {} 166 167 void Compute(OpKernelContext* context) override { 168 const Tensor* input_indices; 169 const Tensor* input_values; 170 const Tensor* input_shape; 171 SparseTensorsMap* map; 172 173 OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices)); 174 OP_REQUIRES_OK(context, context->input("sparse_values", &input_values)); 175 OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape)); 176 OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map)); 177 178 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()), 179 errors::InvalidArgument( 180 "Input indices should be a matrix but received shape ", 181 input_indices->shape().DebugString())); 182 183 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()), 184 errors::InvalidArgument( 185 "Input values should be a vector but received shape ", 186 input_values->shape().DebugString())); 187 188 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()), 189 errors::InvalidArgument( 190 "Input shape should be a vector but received shape ", 191 input_shape->shape().DebugString())); 192 193 TensorShape input_shape_object; 194 OP_REQUIRES_OK(context, 195 TensorShapeUtils::MakeShape(input_shape->vec<int64>().data(), 196 input_shape->NumElements(), 197 &input_shape_object)); 198 SparseTensor st(*input_indices, *input_values, input_shape_object); 199 int64 handle; 200 OP_REQUIRES_OK(context, map->AddSparseTensor(context, st, &handle)); 201 202 Tensor sparse_handle(DT_INT64, TensorShape({})); 203 auto sparse_handle_t = sparse_handle.scalar<int64>(); 204 205 sparse_handle_t() = handle; 206 207 context->set_output(0, sparse_handle); 208 } 209 }; 210 211 REGISTER_KERNEL_BUILDER(Name("AddSparseToTensorsMap").Device(DEVICE_CPU), 212 AddSparseToTensorsMapOp); 213 214 template <typename T> 215 class AddManySparseToTensorsMapOp : public SparseTensorAccessingOp { 216 public: 217 explicit AddManySparseToTensorsMapOp(OpKernelConstruction* context) 218 : SparseTensorAccessingOp(context) {} 219 220 void Compute(OpKernelContext* context) override { 221 const Tensor* input_indices; 222 const Tensor* input_values; 223 const Tensor* input_shape; 224 SparseTensorsMap* map; 225 226 OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices)); 227 OP_REQUIRES_OK(context, context->input("sparse_values", &input_values)); 228 OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape)); 229 OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map)); 230 231 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()), 232 errors::InvalidArgument( 233 "Input indices should be a matrix but received shape ", 234 input_indices->shape().DebugString())); 235 236 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()), 237 errors::InvalidArgument( 238 "Input values should be a vector but received shape ", 239 input_values->shape().DebugString())); 240 241 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()), 242 errors::InvalidArgument( 243 "Input shape should be a vector but received shape ", 244 input_shape->shape().DebugString())); 245 246 int rank = input_shape->NumElements(); 247 248 OP_REQUIRES( 249 context, rank > 1, 250 errors::InvalidArgument( 251 "Rank of input SparseTensor should be > 1, but saw rank: ", rank)); 252 253 TensorShape tensor_input_shape(input_shape->vec<int64>()); 254 gtl::InlinedVector<int64, 8> std_order(rank); 255 std::iota(std_order.begin(), std_order.end(), 0); 256 SparseTensor input_st(*input_indices, *input_values, tensor_input_shape, 257 std_order); 258 259 auto input_shape_t = input_shape->vec<int64>(); 260 const int64 N = input_shape_t(0); 261 262 Tensor sparse_handles(DT_INT64, TensorShape({N})); 263 auto sparse_handles_t = sparse_handles.vec<int64>(); 264 265 OP_REQUIRES_OK(context, input_st.IndicesValid()); 266 267 // We can generate the output shape proto string now, for all 268 // minibatch entries. 269 TensorShape output_shape; 270 OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( 271 input_shape_t.data() + 1, 272 input_shape->NumElements() - 1, &output_shape)); 273 274 // Get groups by minibatch dimension 275 std::unordered_set<int64> visited; 276 sparse::GroupIterable minibatch = input_st.group({0}); 277 for (const auto& subset : minibatch) { 278 const int64 b = subset.group()[0]; 279 visited.insert(b); 280 OP_REQUIRES( 281 context, b > -1 && b < N, 282 errors::InvalidArgument( 283 "Received unexpected column 0 value in input SparseTensor: ", b, 284 " < 0 or >= N (= ", N, ")")); 285 286 const auto indices = subset.indices(); 287 const auto values = subset.values<T>(); 288 const int64 num_entries = values.size(); 289 290 Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1}); 291 Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries}); 292 293 auto output_indices_t = output_indices.matrix<int64>(); 294 auto output_values_t = output_values.vec<T>(); 295 296 for (int i = 0; i < num_entries; ++i) { 297 for (int d = 1; d < rank; ++d) { 298 output_indices_t(i, d - 1) = indices(i, d); 299 } 300 output_values_t(i) = values(i); 301 } 302 303 SparseTensor st_i(output_indices, output_values, output_shape); 304 int64 handle; 305 OP_REQUIRES_OK(context, map->AddSparseTensor(context, st_i, &handle)); 306 sparse_handles_t(b) = handle; 307 } 308 309 // Fill in any gaps; we must provide an empty ST for batch entries 310 // the grouper didn't find. 311 if (visited.size() < N) { 312 Tensor empty_indices(DT_INT64, {0, rank - 1}); 313 Tensor empty_values(DataTypeToEnum<T>::value, {0}); 314 SparseTensor empty_st(empty_indices, empty_values, output_shape); 315 316 for (int64 b = 0; b < N; ++b) { 317 // We skipped this batch entry. 318 if (visited.find(b) == visited.end()) { 319 int64 handle; 320 OP_REQUIRES_OK(context, 321 map->AddSparseTensor(context, empty_st, &handle)); 322 sparse_handles_t(b) = handle; 323 } 324 } 325 } 326 327 context->set_output(0, sparse_handles); 328 } 329 }; 330 331 #define REGISTER_KERNELS(type) \ 332 REGISTER_KERNEL_BUILDER(Name("AddManySparseToTensorsMap") \ 333 .Device(DEVICE_CPU) \ 334 .TypeConstraint<type>("T"), \ 335 AddManySparseToTensorsMapOp<type>) 336 337 TF_CALL_ALL_TYPES(REGISTER_KERNELS); 338 #undef REGISTER_KERNELS 339 340 template <typename T> 341 class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp { 342 public: 343 explicit TakeManySparseFromTensorsMapOp(OpKernelConstruction* context) 344 : SparseTensorAccessingOp(context) {} 345 346 void Compute(OpKernelContext* context) override { 347 SparseTensorsMap* map = nullptr; 348 OP_REQUIRES_OK(context, GetMap(context, false /* is_writing */, &map)); 349 350 const Tensor& sparse_handles = context->input(0); 351 352 OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_handles.shape()), 353 errors::InvalidArgument( 354 "sparse_handles should be a vector but received shape ", 355 sparse_handles.shape().DebugString())); 356 357 int64 N = sparse_handles.shape().dim_size(0); 358 359 OP_REQUIRES( 360 context, N > 0, 361 errors::InvalidArgument("Must have at least 1 serialized SparseTensor, " 362 "but input matrix has 0 rows")); 363 364 std::vector<Tensor> indices_to_concat; 365 std::vector<Tensor> values_to_concat; 366 std::vector<TensorShape> shapes_to_concat; 367 368 const auto& sparse_handles_t = sparse_handles.vec<int64>(); 369 370 std::vector<SparseTensor> sparse_tensors; 371 372 OP_REQUIRES_OK(context, map->RetrieveAndClearSparseTensors( 373 context, sparse_handles_t, &sparse_tensors)); 374 375 for (int64 i = 0; i < N; ++i) { 376 const SparseTensor& st = sparse_tensors[i]; 377 const Tensor& output_indices = st.indices(); 378 const Tensor& output_values = st.values(); 379 const auto output_shape = st.shape(); 380 381 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()), 382 errors::InvalidArgument( 383 "Expected sparse_handles[", i, 384 "] to represent an index matrix but received shape ", 385 output_indices.shape().DebugString())); 386 OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()), 387 errors::InvalidArgument( 388 "Expected sparse_handles[", i, 389 "] to represent a values vector but received shape ", 390 output_values.shape().DebugString())); 391 OP_REQUIRES( 392 context, DataTypeToEnum<T>::value == output_values.dtype(), 393 errors::InvalidArgument( 394 "Requested SparseTensor of type ", 395 DataTypeString(DataTypeToEnum<T>::value), " but SparseTensor[", i, 396 "].values.dtype() == ", DataTypeString(output_values.dtype()))); 397 398 int64 num_entries = output_indices.dim_size(0); 399 OP_REQUIRES(context, num_entries == output_values.dim_size(0), 400 errors::InvalidArgument( 401 "Expected row counts of SparseTensor[", i, 402 "].indices and SparseTensor[", i, 403 "].values to match but they do not: ", num_entries, 404 " vs. ", output_values.dim_size(0))); 405 int rank = output_indices.dim_size(1); 406 OP_REQUIRES( 407 context, rank == output_shape.size(), 408 errors::InvalidArgument("Expected column counts of SparseTensor[", i, 409 "].indices to match size of SparseTensor[", i, 410 "].shape " 411 "but they do not: ", 412 rank, " vs. ", output_shape.size())); 413 414 // Now we expand each SparseTensors' indices and shape by 415 // prefixing a dimension 416 Tensor expanded_indices( 417 DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)})); 418 Tensor expanded_shape(DT_INT64, TensorShape({1 + rank})); 419 const auto& output_indices_t = output_indices.matrix<int64>(); 420 auto expanded_indices_t = expanded_indices.matrix<int64>(); 421 auto expanded_shape_t = expanded_shape.vec<int64>(); 422 expanded_indices_t.chip<1>(0).setZero(); 423 Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1); 424 Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank); 425 expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t; 426 expanded_shape_t(0) = 1; 427 // TODO: copy shape from TensorShape to &expanded_shape_t(1) 428 // std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1)); 429 for (int i = 0; i < rank; ++i) { 430 expanded_shape_t(i + 1) = output_shape[i]; 431 } 432 TensorShape expanded_tensor_shape(expanded_shape_t); 433 434 indices_to_concat.push_back(std::move(expanded_indices)); 435 values_to_concat.push_back(output_values); 436 shapes_to_concat.push_back(std::move(expanded_tensor_shape)); 437 } 438 439 int rank = -1; 440 for (int i = 0; i < N; ++i) { 441 if (rank < 0) rank = shapes_to_concat[i].dims(); 442 OP_REQUIRES(context, rank == shapes_to_concat[i].dims(), 443 errors::InvalidArgument( 444 "Inconsistent rank across SparseTensors: rank prior to " 445 "SparseTensor[", 446 i, "] was: ", rank, " but rank of SparseTensor[", i, 447 "] is: ", shapes_to_concat[i].dims())); 448 } 449 450 // SparseTensor::Concat requires consistent shape for all but the 451 // primary order dimension (dimension 0 in this case). So we get 452 // the maximum value across all the input SparseTensors for each 453 // dimension and use that. 454 TensorShape preconcat_shape(shapes_to_concat[0]); 455 for (int i = 0; i < N; ++i) { 456 for (int d = 0; d < rank; ++d) { 457 preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d), 458 shapes_to_concat[i].dim_size(d))); 459 } 460 } 461 462 // Dimension 0 is the primary dimension. 463 gtl::InlinedVector<int64, 8> std_order(rank); 464 std::iota(std_order.begin(), std_order.end(), 0); 465 466 std::vector<SparseTensor> tensors_to_concat; 467 tensors_to_concat.reserve(N); 468 for (int i = 0; i < N; ++i) { 469 tensors_to_concat.emplace_back(std::move(indices_to_concat[i]), 470 std::move(values_to_concat[i]), 471 preconcat_shape, std_order); 472 } 473 474 SparseTensor output(SparseTensor::Concat<T>(tensors_to_concat)); 475 476 Tensor final_output_shape(DT_INT64, TensorShape({output.dims()})); 477 478 std::copy_n(output.shape().data(), output.dims(), 479 final_output_shape.vec<int64>().data()); 480 481 context->set_output(0, output.indices()); 482 context->set_output(1, output.values()); 483 context->set_output(2, final_output_shape); 484 } 485 }; 486 487 #define REGISTER_KERNELS(type) \ 488 REGISTER_KERNEL_BUILDER(Name("TakeManySparseFromTensorsMap") \ 489 .Device(DEVICE_CPU) \ 490 .TypeConstraint<type>("dtype"), \ 491 TakeManySparseFromTensorsMapOp<type>) 492 493 TF_CALL_ALL_TYPES(REGISTER_KERNELS); 494 #undef REGISTER_KERNELS 495 496 } // namespace tensorflow 497