/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Ops for operating with sets. They are not checked in
// to TensorFlow because we would first like to demonstrate successful
// end-to-end use of these ops in eval, and polish the API a bit, e.g. by
// taking two SparseTensor inputs rather than one dense and one sparse.

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
// TODO(ptucker): Consider switching back to hash_set - I had trouble getting it
// to work with string values.
#include <set>
#include <string>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using ShapeArray = sparse::SparseTensor::ShapeArray;
using VarDimArray = sparse::SparseTensor::VarDimArray;

// Validate rank >= 2.
void CheckRankAtLeast2(OpKernelContext* ctx, const TensorShape& shape) {
  const auto rank = shape.dims();
  OP_REQUIRES(ctx, rank >= 2,
              errors::InvalidArgument("Invalid rank ", rank, "."));
}

// Return the group shape, which is the first n-1 dimensions of shape.
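// For example, input_shape [2,3,4] yields grouped_shape [2,3].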
Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) {
  if (input_shape.size() < 2) {
    // TODO(irving): Why can't 2 be 1 here?
    return errors::InvalidArgument("Shape [", str_util::Join(input_shape, ","),
                                   "] has rank ", input_shape.size(), " < 2");
  }
  // grouped_shape is input_shape[:-1]
  *grouped_shape = ShapeArray(input_shape.begin(), input_shape.end() - 1);
  return Status::OK();
}

// Build `SparseTensor` from indices, values, and shape in inputs
// [base_index, base_index + 3), and validate its rank and indices.
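// For example, with base_index 0, input 0 holds the indices, input 1 the
// values, and input 2 the dense shape.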
sparse::SparseTensor SparseTensorFromContext(OpKernelContext* ctx,
                                             const int32 base_index,
                                             bool validate_indices) {
  // Assume row-major order.
  const TensorShape shape =
      TensorShape(ctx->input(base_index + 2).vec<int64>());
  CheckRankAtLeast2(ctx, shape);
  std::vector<int64> order(shape.dims());
  std::iota(order.begin(), order.end(), 0);

  const sparse::SparseTensor st(ctx->input(base_index),
                                ctx->input(base_index + 1), shape, order);
  if (validate_indices) {
    Status s = st.IndicesValid();
    if (!s.ok()) ctx->SetStatus(s);
  }
  return st;
}

// TODO(ptucker): CheckGroup is just a sanity check on the result of
// SparseTensor.group, consider removing.
// `sparse_tensor_shape` is the shape of the `SparseTensor` from which `group`
// was created, and is used to sanity check the indices in `group`.
template <typename T>
void CheckGroup(OpKernelContext* ctx, const sparse::Group& group,
                const VarDimArray& sparse_tensor_shape) {
  const auto& indices = group.indices();
  const auto& values = group.values<T>();

  // Sanity check: group is non-empty, and indices and values are same size.
  const auto num_values = values.dimension(0);
  OP_REQUIRES(ctx, indices.size() > 0, errors::Internal("Empty group."));
  OP_REQUIRES(
      ctx, indices.dimension(0) == num_values,
      errors::Internal("shape[0] of group indices ", indices.dimension(0),
                       " != values ", num_values, "."));

  // Sanity check: valid indices.
  const auto group_rank = indices.dimension(1);
  const auto expected_rank = sparse_tensor_shape.size();
  OP_REQUIRES(ctx, expected_rank == group_rank,
              errors::Internal("Rank expected ", expected_rank, ", got ",
                               group_rank, "."));
  for (int32 j = 0; j < expected_rank; ++j) {
    const auto dim_size = sparse_tensor_shape[j];
    OP_REQUIRES(
        ctx, dim_size > 0,
        errors::Internal("Invalid dim_size[", j, "] = ", dim_size, "."));
    for (int64 i = 0; i < num_values; ++i) {
      const auto index = indices(i, j);
      OP_REQUIRES(ctx, dim_size > index,
                  errors::Internal("indices[", i, ", ", j, "] expected < ",
                                   dim_size, ", got ", index, "."));
    }
  }
}

// This lets us calculate the row-major index into flattened output.
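// For example, shape [2,3,4] has strides [12,4,1], so element (i,j,k) maps to
// flat index i*12 + j*4 + k.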
const ShapeArray Strides(const VarDimArray& shape) {
  ShapeArray result(shape.size());
  int64 product = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    result[i] = product;
    product *= shape[i];
  }
  return result;
}

// TODO(ptucker): If memory becomes an issue, consider a 2-pass approach to
// eliminate the intermediate `values` data structure - iterate once to
// determine `num_values`, allocate output tensors, then write results directly
// to output tensors.

// TODO(ptucker): Consider sharding work across multiple threads. See
// SparseCrossOp for an example.

// Output `SparseTensor` of shape `output_shape`. `sets` contains a map of
// group indices (i.e., values for all but the last dimension of `output_shape`)
// to set values, each of which will occupy the last dimension of
// `output_shape`.
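// For example, with output_shape [2,2] and sets {[0] -> {a,b}, [1] -> {c}},
// the output indices are [0,0], [0,1], [1,0] and the output values are a, b, c.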
template <typename T>
void OutputSparseTensor(OpKernelContext* ctx, const TensorShape& output_shape,
                        const int64 num_values,
                        const std::map<std::vector<int64>, std::set<T>>& sets) {
  // Allocate 3 output tensors for sparse data.
  Tensor *out_indices_t, *out_values_t, *out_shape_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          0, TensorShape({num_values, output_shape.dims()}),
                          &out_indices_t));
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(1, TensorShape({num_values}), &out_values_t));
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          2, TensorShape({output_shape.dims()}), &out_shape_t));
  auto out_indices_mat = out_indices_t->matrix<int64>();
  auto out_values_flat = out_values_t->vec<T>();

  // For each set, write its indices and values to output tensors.
  int64 value_index = 0;
  for (auto it = sets.begin(); it != sets.end(); ++it) {
    const auto& group_indices = it->first;
    OP_REQUIRES(
        ctx, group_indices.size() == output_shape.dims() - 1,
        errors::Internal("Invalid number of indices ", group_indices.size(),
                         ", expected ", output_shape.dims() - 1, "."));
    const auto& set = it->second;

    // For each set item, write its indices and value to output tensors.
    int64 group_value_index = 0;
    for (auto value = set.begin(); value != set.end();
         ++value, ++value_index, ++group_value_index) {
      // First n-1 dimensions are the group, last dimension is the position in
      // the set.
      for (int32 i = 0; i < group_indices.size(); ++i) {
        out_indices_mat(value_index, i) = group_indices[i];
      }
      out_indices_mat(value_index, group_indices.size()) = group_value_index;

      out_values_flat(value_index) = *value;
    }
  }

  // Write output shape.
  auto out_shape_flat = out_shape_t->vec<int64>();
  for (int32 i = 0; i < output_shape.dims(); ++i) {
    out_shape_flat(i) = output_shape.dim_size(i);
  }
}

bool ValidateIndicesFromContext(OpKernelConstruction* ctx) {
  bool result;
  if (ctx->GetAttr("validate_indices", &result).ok()) {
    return result;
  }
  return true;
}

// Populate `result` set from a group in `input_tensor`. The "group" is defined
// by `group_indices`, which are values for the first n-1 dimensions of
// `input_tensor`. `input_strides` is provided to avoid recalculating it
// multiple times, and is used to calculate the flat index into `input_tensor`
// values.
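// For example, for an input of shape [2,3,4] (strides [12,4,1]) and
// group_indices [1,2], the group is the 4 values at flat offsets [20, 24).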
template <typename T>
void PopulateFromDenseGroup(OpKernelContext* ctx, const Tensor& input_tensor,
                            const VarDimArray& input_strides,
                            const std::vector<int64>& group_indices,
                            std::set<T>* result) {
  OP_REQUIRES(ctx, group_indices.size() == input_strides.size() - 1,
              errors::Internal("group_indices.size ", group_indices.size(),
                               " != input_strides.size-1 ",
                               input_strides.size() - 1, "."));
  result->clear();
  auto input_flat = input_tensor.flat<T>();
  const auto start = std::inner_product(
      group_indices.begin(), group_indices.end(), input_strides.begin(), 0LL);
  const TensorShape& input_shape = input_tensor.shape();
  const auto end = start + input_shape.dim_size(input_shape.dims() - 1);
  for (int64 i = start; i < end; ++i) {
    result->insert(input_flat(i));
  }
}

// Populate `result` set from `group`. `sparse_tensor_shape` is the shape of the
// `SparseTensor` from which `group` was created, and is used to sanity check
// the indices in `group`.
template <typename T>
void PopulateFromSparseGroup(OpKernelContext* ctx, const sparse::Group& group,
                             const VarDimArray& sparse_tensor_shape,
                             std::set<T>* result) {
  CheckGroup<T>(ctx, group, sparse_tensor_shape);
  result->clear();
  const auto& group_values = group.values<T>();
  for (int64 i = 0; i < group_values.size(); ++i) {
    result->insert(group_values(i));
  }
}

template <typename T>
class SetSizeOp : public OpKernel {
 public:
  explicit SetSizeOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), validate_indices_(ValidateIndicesFromContext(ctx)) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  const bool validate_indices_;
};

template <typename T>
void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
  const sparse::SparseTensor set_st =
      SparseTensorFromContext(ctx, 0, validate_indices_);

  // Output shape is same as input except for last dimension, which reduces to
  // the set size of values along that dimension.
  ShapeArray output_shape;
  OP_REQUIRES_OK(ctx, GroupShape(set_st.shape(), &output_shape));
  const auto output_strides = Strides(output_shape);

  TensorShape output_shape_ts;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::MakeShape(output_shape, &output_shape_ts));
  Tensor* out_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape_ts, &out_t));
  auto out = out_t->flat<int32>();
  out.device(ctx->eigen_cpu_device()) = out.constant(static_cast<int32>(0));

  // Group by all but last dimension, create a set of group values, and add set
  // size to output.
  VarDimArray group_ix(set_st.order(), 0, set_st.order().size() - 1);
  std::set<T> group_set;
  for (const auto& group : set_st.group(group_ix)) {
    PopulateFromSparseGroup<T>(ctx, group, set_st.shape(), &group_set);

    const auto group_key = group.group();
    const auto output_index = std::inner_product(
        group_key.begin(), group_key.end(), output_strides.begin(), 0LL);
    out(output_index) = group_set.size();
  }
}

#define _SET_SIZE_REGISTER_KERNEL_BUILDER(T)                     \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("SetSize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SetSizeOp<T>);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int32);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int64);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(string);
#undef _SET_SIZE_REGISTER_KERNEL_BUILDER

enum InputTypes {
  DENSE_DENSE = 0,
  DENSE_SPARSE = 1,
  SPARSE_SPARSE = 2,
};

enum SetOperation { A_MINUS_B = 0, B_MINUS_A = 1, INTERSECTION = 2, UNION = 3 };
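// These correspond to the case-insensitive `set_operation` attr values "a-b",
// "b-a", "intersection", and "union"; see SetOperationFromContext below.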

SetOperation SetOperationFromContext(OpKernelConstruction* ctx) {
  string set_operation_str;
  if (!ctx->GetAttr("set_operation", &set_operation_str).ok()) {
    ctx->CtxFailure(errors::InvalidArgument("Missing set_operation."));
  } else {
    std::transform(set_operation_str.begin(), set_operation_str.end(),
                   set_operation_str.begin(), ::tolower);
    if ("a-b" == set_operation_str) {
      return A_MINUS_B;
    }
    if ("b-a" == set_operation_str) {
      return B_MINUS_A;
    }
    if ("intersection" == set_operation_str) {
      return INTERSECTION;
    }
    if ("union" != set_operation_str) {
      ctx->CtxFailure(errors::InvalidArgument("Invalid set_operation ",
                                              set_operation_str, "."));
    }
  }
  // NOTE: UNION is not a default; this function fails via CtxFailure above if
  // the 'set_operation' attribute is missing or invalid.
  return UNION;
}

// Abstract base class for performing set operations across the last dimension
// of 2 input tensors.
template <typename T>
class SetOperationOp : public OpKernel {
 public:
  SetOperationOp(OpKernelConstruction* ctx, InputTypes input_types)
      : OpKernel(ctx),
        set_operation_(SetOperationFromContext(ctx)),
        validate_indices_(ValidateIndicesFromContext(ctx)),
        input_types_(input_types) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  void ApplySetOperation(const std::set<T>& set1, const std::set<T>& set2,
                         std::set<T>* result) const;
  void ComputeDenseToDense(OpKernelContext* ctx) const;
  void ComputeDenseToSparse(OpKernelContext* ctx) const;
  void ComputeSparseToSparse(OpKernelContext* ctx) const;
  const SetOperation set_operation_;
  const bool validate_indices_;
  const InputTypes input_types_;
};

template <typename T>
void SetOperationOp<T>::ApplySetOperation(const std::set<T>& set1,
                                          const std::set<T>& set2,
                                          std::set<T>* result) const {
  switch (set_operation_) {
    case A_MINUS_B:
      std::set_difference(set1.begin(), set1.end(), set2.begin(), set2.end(),
                          std::inserter(*result, result->begin()));
      break;
    case B_MINUS_A:
      std::set_difference(set2.begin(), set2.end(), set1.begin(), set1.end(),
                          std::inserter(*result, result->begin()));
      break;
    case INTERSECTION:
      std::set_intersection(set1.begin(), set1.end(), set2.begin(), set2.end(),
                            std::inserter(*result, result->begin()));
      break;
    case UNION:
      std::set_union(set1.begin(), set1.end(), set2.begin(), set2.end(),
                     std::inserter(*result, result->begin()));
      break;
  }
}

// Validate shapes have the same dimensions.
Status CheckShapesMatch(VarDimArray shape1, VarDimArray shape2) {
  if (shape1 != shape2) {
    return errors::InvalidArgument("Mismatched shapes [",
                                   str_util::Join(shape1, ","), "] vs [",
                                   str_util::Join(shape2, ","), "]");
  }
  return Status::OK();
}

// Validate ranks are the same, and all but last dimension are the same.
// Return GroupShape.
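// For example, shapes [2,3,4] and [2,3,7] both yield group shape [2,3].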
Status GroupShapeFromInputs(VarDimArray shape1, VarDimArray shape2,
                            ShapeArray* group_shape) {
  ShapeArray group_shape_1;
  TF_RETURN_IF_ERROR(GroupShape(shape1, &group_shape_1));
  ShapeArray group_shape_2;
  TF_RETURN_IF_ERROR(GroupShape(shape2, &group_shape_2));
  TF_RETURN_IF_ERROR(CheckShapesMatch(group_shape_1, group_shape_2));
  *group_shape = group_shape_1;
  return Status::OK();
}

// Split `flat_group_index` into separate dimensions based on `group_shape`.
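// For example, flat_group_index 5 with group_shape [2,3] yields
// group_indices [1,2].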
void PopulateGroupIndices(const int64 flat_group_index, VarDimArray group_shape,
                          std::vector<int64>* group_indices) {
  group_indices->clear();
  int64 running_flat_group_index = flat_group_index;
  for (int group_dim_index = group_shape.size() - 1; group_dim_index >= 0;
       --group_dim_index) {
    const auto group_dim = group_shape[group_dim_index];
    group_indices->insert(group_indices->begin(),
                          running_flat_group_index % group_dim);
    running_flat_group_index /= group_dim;
  }
}

ShapeArray TensorShapeToArray(const TensorShape& t) {
  ShapeArray vec(t.dims());
  for (int i = 0; i < t.dims(); ++i) vec[i] = t.dim_size(i);
  return vec;
}

// `ctx` contains set1 and set2 dense tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  const Tensor& set2_t = ctx->input(1);
  // The following should stay in sync with `_dense_to_dense_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToDenseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  const auto shape1 = TensorShapeToArray(set1_t.shape());
  const auto shape2 = TensorShapeToArray(set2_t.shape());
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(shape1, shape2, &group_shape));

  const auto set1_strides = Strides(shape1);
  const auto set2_strides = Strides(shape2);

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64 num_result_values = 0;
  int64 max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  std::vector<int64> group_indices;
  int64 num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64 flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);
    PopulateFromDenseGroup<T>(ctx, set2_t, set2_strides, group_indices,
                              &set2_group_set);

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// `ctx` contains dense set1 and sparse set2 tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  const sparse::SparseTensor set2_st =
      SparseTensorFromContext(ctx, 1, validate_indices_);
  // The following should stay in sync with `_dense_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(TensorShapeToArray(set1_t.shape()),
                                           set2_st.shape(), &group_shape));

  const ShapeArray set1_strides = Strides(TensorShapeToArray(set1_t.shape()));

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64 num_result_values = 0;
  int64 max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  auto set2_grouper = set2_st.group(
      VarDimArray(set2_st.order(), 0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();
  std::vector<int64> group_indices;
  int64 num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64 flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);

    // Get values from set1.
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (set2_group_it != set2_grouper.end()) {
      const auto& group = *set2_group_it;
      const auto set2_group_indices = group.group();
      OP_REQUIRES(
          ctx, set2_group_indices.size() == group_indices.size(),
          errors::InvalidArgument("Invalid number of group indices ",
                                  set2_group_indices.size(), ", expected ",
                                  group_indices.size(), "."));
      bool group_match = true;
      for (int32 i = 0; group_match && (i < set2_group_indices.size()); ++i) {
        if (set2_group_indices[i] != group_indices[i]) {
          group_match = false;
        }
      }
      if (group_match) {
        PopulateFromSparseGroup<T>(ctx, group, set2_st.shape(),
                                   &set2_group_set);
        ++set2_group_it;
      }
    }

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// This is used to determine which group iterator is less than the other, based
// on row-major ordering of indices.
// An empty index list indicates end of iteration, which is interpreted as "max"
// for the purposes of comparison; i.e., non-empty < empty.
// Return 0 if both groups are empty, or both non-empty with the same values.
// Return <0 if set1 < set2, or set2 is empty.
// Return >0 if set2 < set1, or set1 is empty.
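// For example, [0,1] vs [0,2] yields -1, [1,0] vs [0,2] yields 1, and
// {} vs [0,0] yields 1.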
void CompareGroups(OpKernelContext* ctx,
                   const std::vector<int64>& set1_group_indices,
                   const std::vector<int64>& set2_group_indices,
                   int64* result) {
  if (set1_group_indices.empty()) {
    *result = set2_group_indices.empty() ? 0 : 1;
    return;
  }
  if (set2_group_indices.empty()) {
    *result = set1_group_indices.empty() ? 0 : -1;
    return;
  }
  OP_REQUIRES(ctx, set1_group_indices.size() == set2_group_indices.size(),
              errors::InvalidArgument("Mismatched group dims ",
                                      set1_group_indices.size(), " vs ",
                                      set2_group_indices.size(), "."));
  for (int32 i = 0; i < set1_group_indices.size(); ++i) {
    *result = set1_group_indices[i] - set2_group_indices[i];
    if (*result != 0) {
      return;
    }
  }
}

// Empty indices vector represents iteration end in `CompareGroups`.
const std::vector<int64> GROUP_ITER_END;

// `ctx` contains set1 and set2 sparse tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
  const sparse::SparseTensor set1_st =
      SparseTensorFromContext(ctx, 0, validate_indices_);
  const sparse::SparseTensor set2_st =
      SparseTensorFromContext(ctx, 3, validate_indices_);
  // The following should stay in sync with `_sparse_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `SparseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(set1_st.shape(), set2_st.shape(),
                                           &group_shape));

  const ShapeArray set1_strides = Strides(set1_st.shape());
  const ShapeArray set2_strides = Strides(set2_st.shape());

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64 num_result_values = 0;
  int64 max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  auto set1_grouper = set1_st.group(
      VarDimArray(set1_st.order(), 0, set1_st.order().size() - 1));
  auto set1_group_it = set1_grouper.begin();
  auto set2_grouper = set2_st.group(
      VarDimArray(set2_st.order(), 0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();

  // Group by rows, and iterate over rows of both sets in parallel, creating a
  // set for each row.
  while ((set1_group_it != set1_grouper.end()) ||
         (set2_group_it != set2_grouper.end())) {
    const std::vector<int64>& set1_group_indices =
        (set1_group_it == set1_grouper.end()) ? GROUP_ITER_END
                                              : (*set1_group_it).group();
    const std::vector<int64>& set2_group_indices =
        (set2_group_it == set2_grouper.end()) ? GROUP_ITER_END
                                              : (*set2_group_it).group();

    int64 compare_groups;
    CompareGroups(ctx, set1_group_indices, set2_group_indices, &compare_groups);
    const std::vector<int64>* group_indices = nullptr;

    // Get values from set1, if applicable.
    set1_group_set.clear();
    if (compare_groups <= 0) {
      PopulateFromSparseGroup<T>(ctx, *set1_group_it, set1_st.shape(),
                                 &set1_group_set);
      ++set1_group_it;
      group_indices = &set1_group_indices;
    }

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (compare_groups >= 0) {
      PopulateFromSparseGroup<T>(ctx, *set2_group_it, set2_st.shape(),
                                 &set2_group_set);
      ++set2_group_it;
      group_indices = &set2_group_indices;
    }

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[*group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// Given set1 of shape [b, n1] and set2 of shape [b, n2], populate the result
// sparse tensor with [b, n3] values, where each row `i` contains the result of
// the set operation on elements from set1[i] and set2[i]. `n3` is the number
// of elements in that result row.
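// For example, for INTERSECTION with set1 = [[1,2],[3,4]] and
// set2 = [[2,3],[4,5]], the result rows are {2} and {4}, giving a [2,1]
// `SparseTensor`.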
template <typename T>
void SetOperationOp<T>::Compute(OpKernelContext* ctx) {
  switch (input_types_) {
    case DENSE_DENSE:
      ComputeDenseToDense(ctx);
      break;
    case DENSE_SPARSE:
      ComputeDenseToSparse(ctx);
      break;
    case SPARSE_SPARSE:
      ComputeSparseToSparse(ctx);
      break;
  }
}

template <typename T>
class DenseToDenseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToDenseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_DENSE) {}
};

#define _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToDenseSetOperation")       \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<T>("T"),           \
                          DenseToDenseSetOperationOp<T>);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string);
#undef _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class DenseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_SPARSE) {}
};

#define _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<T>("T"),            \
                          DenseToSparseSetOperationOp<T>);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string);
#undef _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class SparseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit SparseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, SPARSE_SPARSE) {}
};

#define _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("SparseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                  \
                              .TypeConstraint<T>("T"),             \
                          SparseToSparseSetOperationOp<T>);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(string);
#undef _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

}  // namespace tensorflow