/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Ops for operating with sets. They are not checked in
// to TensorFlow because we would first like to demonstrate successful
// end-to-end use of these ops in eval, and polish the API a bit, e.g. by
// taking two SparseTensors rather than one dense and one sparse.

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
// TODO(ptucker): Consider switching back to hash_set - I had trouble getting
// it to work with string values.
#include <set>
#include <string>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using ShapeArray = sparse::SparseTensor::ShapeArray;
using VarDimArray = sparse::SparseTensor::VarDimArray;

// Validate rank >= 2.
void CheckRankAtLeast2(OpKernelContext* ctx, const TensorShape& shape) {
  const auto rank = shape.dims();
  OP_REQUIRES(ctx, rank >= 2,
              errors::InvalidArgument("Invalid rank ", rank, "."));
}

// Returns the group shape, which is the first n-1 dimensions of `input_shape`.
Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) {
  if (input_shape.size() < 2) {
    // TODO(irving): Why can't 2 be 1 here?
    return errors::InvalidArgument("Shape [", str_util::Join(input_shape, ","),
                                   "] has rank ", input_shape.size(), " < 2");
  }
  // grouped_shape is input_shape[:-1].
  *grouped_shape = ShapeArray(input_shape.begin(), input_shape.end() - 1);
  return Status::OK();
}

// Build a `SparseTensor` from the indices, values, and shape in inputs
// [base_index, base_index + 3), and validate its rank and indices.
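// For example (illustrative), with base_index = 0 the three inputs might be
//   input(0) indices = [[0, 0], [0, 1], [1, 0]]  (shape [3, 2])
//   input(1) values  = [5, 6, 7]                 (shape [3])
//   input(2) shape   = [2, 2]
// which together describe the dense tensor [[5, 6], [7, 0]].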
sparse::SparseTensor SparseTensorFromContext(OpKernelContext* ctx,
                                             const int32 base_index,
                                             bool validate_indices) {
  // Assume row-major order.
  const TensorShape shape =
      TensorShape(ctx->input(base_index + 2).vec<int64>());
  CheckRankAtLeast2(ctx, shape);
  std::vector<int64> order(shape.dims());
  std::iota(order.begin(), order.end(), 0);

  const sparse::SparseTensor st(ctx->input(base_index),
                                ctx->input(base_index + 1), shape, order);
  if (validate_indices) {
    Status s = st.IndicesValid();
    if (!s.ok()) ctx->SetStatus(s);
  }
  return st;
}

// TODO(ptucker): CheckGroup is just a sanity check on the result of
// SparseTensor.group, consider removing.
// `sparse_tensor_shape` is the shape of the `SparseTensor` from which `group`
// was created, and is used to sanity check the indices in `group`.
template <typename T>
void CheckGroup(OpKernelContext* ctx, const sparse::Group& group,
                const VarDimArray& sparse_tensor_shape) {
  const auto& indices = group.indices();
  const auto& values = group.values<T>();

  // Sanity check: group is non-empty, and indices and values are the same
  // size.
  const auto num_values = values.dimension(0);
  OP_REQUIRES(ctx, indices.size() > 0, errors::Internal("Empty group."));
  OP_REQUIRES(
      ctx, indices.dimension(0) == num_values,
      errors::Internal("shape[0] of group indices ", indices.dimension(0),
                       " != values ", num_values, "."));

  // Sanity check: valid indices.
  const auto group_rank = indices.dimension(1);
  const auto expected_rank = sparse_tensor_shape.size();
  OP_REQUIRES(ctx, expected_rank == group_rank,
              errors::Internal("Rank expected ", expected_rank, ", got ",
                               group_rank, "."));
  for (int32 j = 0; j < expected_rank; ++j) {
    const auto dim_size = sparse_tensor_shape[j];
    OP_REQUIRES(
        ctx, dim_size > 0,
        errors::Internal("Invalid dim_size[", j, "] = ", dim_size, "."));
    for (int64 i = 0; i < num_values; ++i) {
      const auto index = indices(i, j);
      OP_REQUIRES(ctx, dim_size > index,
                  errors::Internal("indices[", i, ", ", j, "] expected < ",
                                   dim_size, ", got ", index, "."));
    }
  }
}

// This lets us calculate the row-major index into flattened output.
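// For example, Strides([2, 3, 4]) returns [12, 4, 1], so element (i, j, k)
// maps to flat index i * 12 + j * 4 + k.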
const ShapeArray Strides(const VarDimArray& shape) {
  ShapeArray result(shape.size());
  int64 product = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    result[i] = product;
    product *= shape[i];
  }
  return result;
}

// TODO(ptucker): If memory becomes an issue, consider a 2-pass approach to
// eliminate the intermediate `values` data structure - iterate once to
// determine `num_values`, allocate output tensors, then write results directly
// to output tensors.

// TODO(ptucker): Consider sharding work across multiple threads. See
// SparseCrossOp for an example.

// Output a `SparseTensor` of shape `output_shape`. `sets` maps group indices
// (i.e., values for all but the last dimension of `output_shape`) to sets of
// values, each of which will occupy the last dimension of `output_shape`.
template <typename T>
void OutputSparseTensor(OpKernelContext* ctx, const TensorShape& output_shape,
                        const int64 num_values,
                        const std::map<std::vector<int64>, std::set<T>>& sets) {
  // Allocate 3 output tensors for sparse data.
  Tensor *out_indices_t, *out_values_t, *out_shape_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          0, TensorShape({num_values, output_shape.dims()}),
                          &out_indices_t));
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(1, TensorShape({num_values}), &out_values_t));
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          2, TensorShape({output_shape.dims()}), &out_shape_t));
  auto out_indices_mat = out_indices_t->matrix<int64>();
  auto out_values_flat = out_values_t->vec<T>();

  // For each set, write its indices and values to the output tensors.
  int64 value_index = 0;
  for (auto it = sets.begin(); it != sets.end(); ++it) {
    const auto& group_indices = it->first;
    OP_REQUIRES(
        ctx, group_indices.size() == output_shape.dims() - 1,
        errors::Internal("Invalid number of indices ", group_indices.size(),
                         ", expected ", output_shape.dims() - 1, "."));
    const auto& set = it->second;

    // For each set item, write its indices and value to the output tensors.
    int64 group_value_index = 0;
    for (auto value = set.begin(); value != set.end();
         ++value, ++value_index, ++group_value_index) {
      // The first n-1 dimensions are the group; the last dimension is the
      // position within the set.
      for (int32 i = 0; i < group_indices.size(); ++i) {
        out_indices_mat(value_index, i) = group_indices[i];
      }
      out_indices_mat(value_index, group_indices.size()) = group_value_index;

      out_values_flat(value_index) = *value;
    }
  }

  // Write the output shape.
  auto out_shape_flat = out_shape_t->vec<int64>();
  for (int32 i = 0; i < output_shape.dims(); ++i) {
    out_shape_flat(i) = output_shape.dim_size(i);
  }
}

bool ValidateIndicesFromContext(OpKernelConstruction* ctx) {
  bool result;
  if (ctx->GetAttr("validate_indices", &result).ok()) {
    return result;
  }
  return true;
}

// Populate the `result` set from a group in `input_tensor`. A "group" is
// defined by `group_indices`, which are values for the first n-1 dimensions of
// `input_tensor`. `input_strides` is provided to avoid recalculating it
// multiple times, and is used to calculate the flat index into `input_tensor`
// values.
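// For example (illustrative), given an input_tensor of shape [2, 3] with
// values [[1, 1, 2], [3, 4, 4]], input_strides [3, 1], and group_indices [1],
// the flat range is [3, 6), so `result` becomes {3, 4}.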
template <typename T>
void PopulateFromDenseGroup(OpKernelContext* ctx, const Tensor& input_tensor,
                            const VarDimArray& input_strides,
                            const std::vector<int64>& group_indices,
                            std::set<T>* result) {
  OP_REQUIRES(ctx, group_indices.size() == input_strides.size() - 1,
              errors::Internal("group_indices.size ", group_indices.size(),
                               " != input_strides.size-1 ",
                               input_strides.size() - 1, "."));
  result->clear();
  auto input_flat = input_tensor.flat<T>();
  const auto start = std::inner_product(
      group_indices.begin(), group_indices.end(), input_strides.begin(), 0LL);
  const TensorShape& input_shape = input_tensor.shape();
  const auto end = start + input_shape.dim_size(input_shape.dims() - 1);
  for (int64 i = start; i < end; ++i) {
    result->insert(input_flat(i));
  }
}

// Populate the `result` set from `group`. `sparse_tensor_shape` is the shape
// of the `SparseTensor` from which `group` was created, and is used to sanity
// check the indices in `group`.
template <typename T>
void PopulateFromSparseGroup(OpKernelContext* ctx, const sparse::Group& group,
                             const VarDimArray& sparse_tensor_shape,
                             std::set<T>* result) {
  CheckGroup<T>(ctx, group, sparse_tensor_shape);
  result->clear();
  const auto& group_values = group.values<T>();
  for (int64 i = 0; i < group_values.size(); ++i) {
    result->insert(group_values(i));
  }
}
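// `SetSizeOp` computes the number of unique values along the last dimension of
// a `SparseTensor`. For example (illustrative), for a SparseTensor of shape
// [2, 3] whose rows contain values {1, 2, 2} and {3}, the output is the dense
// int32 tensor [2, 1].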
template <typename T>
class SetSizeOp : public OpKernel {
 public:
  explicit SetSizeOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), validate_indices_(ValidateIndicesFromContext(ctx)) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  const bool validate_indices_;
};

template <typename T>
void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
  const sparse::SparseTensor set_st =
      SparseTensorFromContext(ctx, 0, validate_indices_);

  // The output shape is the same as the input shape, except the last dimension
  // is reduced to the set size of values along that dimension.
  ShapeArray output_shape;
  OP_REQUIRES_OK(ctx, GroupShape(set_st.shape(), &output_shape));
  const auto output_strides = Strides(output_shape);

  TensorShape output_shape_ts;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::MakeShape(output_shape, &output_shape_ts));
  Tensor* out_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape_ts, &out_t));
  auto out = out_t->flat<int32>();
  out.device(ctx->eigen_cpu_device()) = out.constant(static_cast<int32>(0));

  // Group by all but the last dimension, create a set of group values, and
  // write the set size to the output.
  VarDimArray group_ix(set_st.order(), 0, set_st.order().size() - 1);
  std::set<T> group_set;
  for (const auto& group : set_st.group(group_ix)) {
    PopulateFromSparseGroup<T>(ctx, group, set_st.shape(), &group_set);

    const auto group_key = group.group();
    const auto output_index = std::inner_product(
        group_key.begin(), group_key.end(), output_strides.begin(), 0LL);
    out(output_index) = group_set.size();
  }
}

#define _SET_SIZE_REGISTER_KERNEL_BUILDER(T)                     \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("SetSize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SetSizeOp<T>);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int32);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int64);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(string);
#undef _SET_SIZE_REGISTER_KERNEL_BUILDER

enum InputTypes {
  DENSE_DENSE = 0,
  DENSE_SPARSE = 1,
  SPARSE_SPARSE = 2,
};

enum SetOperation { A_MINUS_B = 0, B_MINUS_A = 1, INTERSECTION = 2, UNION = 3 };

SetOperation SetOperationFromContext(OpKernelConstruction* ctx) {
  string set_operation_str;
  if (!ctx->GetAttr("set_operation", &set_operation_str).ok()) {
    ctx->CtxFailure(errors::InvalidArgument("Missing set_operation."));
  } else {
    std::transform(set_operation_str.begin(), set_operation_str.end(),
                   set_operation_str.begin(), ::tolower);
    if ("a-b" == set_operation_str) {
      return A_MINUS_B;
    }
    if ("b-a" == set_operation_str) {
      return B_MINUS_A;
    }
    if ("intersection" == set_operation_str) {
      return INTERSECTION;
    }
    if ("union" != set_operation_str) {
      ctx->CtxFailure(errors::InvalidArgument("Invalid set_operation ",
                                              set_operation_str, "."));
    }
  }
  // NOTE: `UNION` is not a default; if the `set_operation` attribute was
  // missing or invalid, this function has already failed via `CtxFailure`
  // above.
  return UNION;
}

// Abstract base class for performing set operations across the last dimension
// of two input tensors.
template <typename T>
class SetOperationOp : public OpKernel {
 public:
  SetOperationOp(OpKernelConstruction* ctx, InputTypes input_types)
      : OpKernel(ctx),
        set_operation_(SetOperationFromContext(ctx)),
        validate_indices_(ValidateIndicesFromContext(ctx)),
        input_types_(input_types) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  void ApplySetOperation(const std::set<T>& set1, const std::set<T>& set2,
                         std::set<T>* result) const;
  void ComputeDenseToDense(OpKernelContext* ctx) const;
  void ComputeDenseToSparse(OpKernelContext* ctx) const;
  void ComputeSparseToSparse(OpKernelContext* ctx) const;
  const SetOperation set_operation_;
  const bool validate_indices_;
  const InputTypes input_types_;
};

template <typename T>
void SetOperationOp<T>::ApplySetOperation(const std::set<T>& set1,
                                          const std::set<T>& set2,
                                          std::set<T>* result) const {
  switch (set_operation_) {
    case A_MINUS_B:
      std::set_difference(set1.begin(), set1.end(), set2.begin(), set2.end(),
                          std::inserter(*result, result->begin()));
      break;
    case B_MINUS_A:
      std::set_difference(set2.begin(), set2.end(), set1.begin(), set1.end(),
                          std::inserter(*result, result->begin()));
      break;
    case INTERSECTION:
      std::set_intersection(set1.begin(), set1.end(), set2.begin(), set2.end(),
                            std::inserter(*result, result->begin()));
      break;
    case UNION:
      std::set_union(set1.begin(), set1.end(), set2.begin(), set2.end(),
                     std::inserter(*result, result->begin()));
      break;
  }
}

// Validate that the shapes have the same dimensions.
Status CheckShapesMatch(VarDimArray shape1, VarDimArray shape2) {
  if (shape1 != shape2) {
    return errors::InvalidArgument("Mismatched shapes [",
                                   str_util::Join(shape1, ","), "] vs [",
                                   str_util::Join(shape2, ","), "]");
  }
  return Status::OK();
}

// Validate that the ranks are the same, and all but the last dimension are the
// same. Returns the group shape.
Status GroupShapeFromInputs(VarDimArray shape1, VarDimArray shape2,
                            ShapeArray* group_shape) {
  ShapeArray group_shape_1;
  TF_RETURN_IF_ERROR(GroupShape(shape1, &group_shape_1));
  ShapeArray group_shape_2;
  TF_RETURN_IF_ERROR(GroupShape(shape2, &group_shape_2));
  TF_RETURN_IF_ERROR(CheckShapesMatch(group_shape_1, group_shape_2));
  *group_shape = group_shape_1;
  return Status::OK();
}

// Split `flat_group_index` into separate dimensions based on `group_shape`.
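// For example, PopulateGroupIndices(5, [2, 3], &indices) sets `indices` to
// [1, 2], since 5 = 1 * 3 + 2 in row-major order.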
void PopulateGroupIndices(const int64 flat_group_index, VarDimArray group_shape,
                          std::vector<int64>* group_indices) {
  group_indices->clear();
  int64 running_flat_group_index = flat_group_index;
  for (int group_dim_index = group_shape.size() - 1; group_dim_index >= 0;
       --group_dim_index) {
    const auto group_dim = group_shape[group_dim_index];
    group_indices->insert(group_indices->begin(),
                          running_flat_group_index % group_dim);
    running_flat_group_index /= group_dim;
  }
}

ShapeArray TensorShapeToArray(const TensorShape& t) {
  ShapeArray vec(t.dims());
  for (int i = 0; i < t.dims(); ++i) vec[i] = t.dim_size(i);
  return vec;
}

// `ctx` contains set1 and set2 dense tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and output the result as a `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
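// For example (illustrative), with set_operation "intersection",
// set1 = [[1, 2], [3, 4]] and set2 = [[2, 9], [5, 6]] yield one non-empty
// group set, {2} at group [0] (group [1] is empty and omitted), so the output
// SparseTensor has shape [2, 1], indices [[0, 0]], and values [2].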
template <typename T>
void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  const Tensor& set2_t = ctx->input(1);
  // The following should stay in sync with `_dense_to_dense_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToDenseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  const auto shape1 = TensorShapeToArray(set1_t.shape());
  const auto shape2 = TensorShapeToArray(set2_t.shape());
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(shape1, shape2, &group_shape));

  const auto set1_strides = Strides(shape1);
  const auto set2_strides = Strides(shape2);

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64 num_result_values = 0;
  int64 max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  std::vector<int64> group_indices;
  int64 num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64 flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);
    PopulateFromDenseGroup<T>(ctx, set2_t, set2_strides, group_indices,
                              &set2_group_set);

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// `ctx` contains a dense set1 tensor and a sparse set2 tensor.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and output the result as a `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  const sparse::SparseTensor set2_st =
      SparseTensorFromContext(ctx, 1, validate_indices_);
  // The following should stay in sync with `_dense_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(TensorShapeToArray(set1_t.shape()),
                                           set2_st.shape(), &group_shape));

  const ShapeArray set1_strides = Strides(TensorShapeToArray(set1_t.shape()));

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64 num_result_values = 0;
  int64 max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  auto set2_grouper = set2_st.group(
      VarDimArray(set2_st.order(), 0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();
  std::vector<int64> group_indices;
  int64 num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64 flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);

    // Get values from set1.
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (set2_group_it != set2_grouper.end()) {
      const auto& group = *set2_group_it;
      const auto set2_group_indices = group.group();
      OP_REQUIRES(
          ctx, set2_group_indices.size() == group_indices.size(),
          errors::InvalidArgument("Invalid number of group indices ",
                                  set2_group_indices.size(), ", expected ",
                                  group_indices.size(), "."));
      bool group_match = true;
      for (int32 i = 0; group_match && (i < set2_group_indices.size()); ++i) {
        if (set2_group_indices[i] != group_indices[i]) {
          group_match = false;
        }
      }
      if (group_match) {
        PopulateFromSparseGroup<T>(ctx, group, set2_st.shape(),
                                   &set2_group_set);
        ++set2_group_it;
      }
    }

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// This is used to determine which group iterator is less than the other, based
// on row-major ordering of indices.
// An empty index list indicates end of iteration, which is interpreted as
// "max" for the purposes of comparison; i.e., non-empty < empty.
// Sets `result` to 0 if both groups are empty, or both are non-empty with the
// same values; to <0 if set1 < set2, or set2 is empty; and to >0 if
// set2 < set1, or set1 is empty.
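// For example, CompareGroups(ctx, {0, 1}, {0, 2}, &r) sets r to -1 (set1's
// group sorts first), while CompareGroups(ctx, {}, {0, 2}, &r) sets r to 1
// (set1 is exhausted, so set2's group comes first).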
void CompareGroups(OpKernelContext* ctx,
                   const std::vector<int64>& set1_group_indices,
                   const std::vector<int64>& set2_group_indices,
                   int64* result) {
  if (set1_group_indices.empty()) {
    *result = set2_group_indices.empty() ? 0 : 1;
    return;
  }
  if (set2_group_indices.empty()) {
    // set1 is non-empty here, so it always sorts first.
    *result = -1;
    return;
  }
  OP_REQUIRES(ctx, set1_group_indices.size() == set2_group_indices.size(),
              errors::InvalidArgument("Mismatched group dims ",
                                      set1_group_indices.size(), " vs ",
                                      set2_group_indices.size(), "."));
  for (int32 i = 0; i < set1_group_indices.size(); ++i) {
    *result = set1_group_indices[i] - set2_group_indices[i];
    if (*result != 0) {
      return;
    }
  }
}

// An empty indices vector represents iteration end in `CompareGroups`.
const std::vector<int64> GROUP_ITER_END;

// `ctx` contains set1 and set2 sparse tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and output the result as a `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
  const sparse::SparseTensor set1_st =
      SparseTensorFromContext(ctx, 0, validate_indices_);
  const sparse::SparseTensor set2_st =
      SparseTensorFromContext(ctx, 3, validate_indices_);
  // The following should stay in sync with `_sparse_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `SparseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(set1_st.shape(), set2_st.shape(),
                                           &group_shape));

  const ShapeArray set1_strides = Strides(set1_st.shape());
  const ShapeArray set2_strides = Strides(set2_st.shape());

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64 num_result_values = 0;
  int64 max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  auto set1_grouper = set1_st.group(
      VarDimArray(set1_st.order(), 0, set1_st.order().size() - 1));
  auto set1_group_it = set1_grouper.begin();
  auto set2_grouper = set2_st.group(
      VarDimArray(set2_st.order(), 0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();

  // Group by rows, and iterate over rows of both sets in parallel, creating a
  // set for each row.
  while ((set1_group_it != set1_grouper.end()) ||
         (set2_group_it != set2_grouper.end())) {
    const std::vector<int64>& set1_group_indices =
        (set1_group_it == set1_grouper.end()) ? GROUP_ITER_END
                                              : (*set1_group_it).group();
    const std::vector<int64>& set2_group_indices =
        (set2_group_it == set2_grouper.end()) ? GROUP_ITER_END
                                              : (*set2_group_it).group();

    int64 compare_groups = 0;
    CompareGroups(ctx, set1_group_indices, set2_group_indices,
                  &compare_groups);
    const std::vector<int64>* group_indices = nullptr;

    // Get values from set1, if applicable.
    set1_group_set.clear();
    if (compare_groups <= 0) {
      PopulateFromSparseGroup<T>(ctx, *set1_group_it, set1_st.shape(),
                                 &set1_group_set);
      ++set1_group_it;
      group_indices = &set1_group_indices;
    }
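    // When the groups are equal (compare_groups == 0), both iterators advance
    // and the operation sees both sets; otherwise only the lagging iterator
    // advances, and the other side's set stays empty for this group.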
    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (compare_groups >= 0) {
      PopulateFromSparseGroup<T>(ctx, *set2_group_it, set2_st.shape(),
                                 &set2_group_set);
      ++set2_group_it;
      group_indices = &set2_group_indices;
    }

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[*group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// Given set1 of shape [b, n1] and set2 of shape [b, n2], populate the result
// sparse tensor with [b, n3] values, where row `i` contains the result of the
// set operation on elements from set1[i] and set2[i], and `n3` is the size of
// the largest result row.
template <typename T>
void SetOperationOp<T>::Compute(OpKernelContext* ctx) {
  switch (input_types_) {
    case DENSE_DENSE:
      ComputeDenseToDense(ctx);
      break;
    case DENSE_SPARSE:
      ComputeDenseToSparse(ctx);
      break;
    case SPARSE_SPARSE:
      ComputeSparseToSparse(ctx);
      break;
  }
}

template <typename T>
class DenseToDenseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToDenseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_DENSE) {}
};

#define _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToDenseSetOperation")       \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<T>("T"),           \
                          DenseToDenseSetOperationOp<T>);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string);
#undef _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class DenseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_SPARSE) {}
};

#define _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<T>("T"),            \
                          DenseToSparseSetOperationOp<T>);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string);
#undef _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER
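// NOTE: `SparseToSparseSetOperation` takes set1 as inputs 0..2 (indices,
// values, shape) and set2 as inputs 3..5, matching the `base_index` arguments
// passed to `SparseTensorFromContext` in `ComputeSparseToSparse` above.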
template <typename T>
class SparseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit SparseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, SPARSE_SPARSE) {}
};

#define _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("SparseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                  \
                              .TypeConstraint<T>("T"),             \
                          SparseToSparseSetOperationOp<T>);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string);
#undef _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

}  // namespace tensorflow