Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Contains OP to generate sparse crosses.
     17 #include <assert.h>
     18 #include <limits>
     19 #include <string>
     20 #include <vector>
     21 
     22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     23 #include "tensorflow/core/framework/kernel_def_builder.h"
     24 #include "tensorflow/core/framework/op_def_builder.h"
     25 #include "tensorflow/core/framework/op_kernel.h"
     26 #include "tensorflow/core/framework/tensor.h"
     27 #include "tensorflow/core/framework/tensor_shape.h"
     28 #include "tensorflow/core/framework/types.h"
     29 #include "tensorflow/core/lib/core/stringpiece.h"
     30 #include "tensorflow/core/lib/strings/str_util.h"
     31 #include "tensorflow/core/platform/fingerprint.h"
     32 #include "tensorflow/core/util/work_sharder.h"
     33 
     34 namespace tensorflow {
     35 
     36 namespace {
     37 // An interface that represents a column with batches.
     38 template <typename InternalType>
     39 class ColumnInterface {
     40  public:
     41   // Returns the number of features in the specified batch.
     42   virtual int64 FeatureCount(int64 batch) const = 0;
     43 
     44   // Returns the fingerprint of nth feature from the specified batch.
     45   virtual InternalType Feature(int64 batch, int64 n) const = 0;
     46 
     47   virtual ~ColumnInterface() {}
     48 };
     49 
     50 // A column that is backed by a sparse tensor.
     51 template <typename InternalType>
     52 class SparseTensorColumn : public ColumnInterface<InternalType> {
     53  public:
     54   SparseTensorColumn(const Tensor& values, std::vector<int64> feature_counts,
     55                      std::vector<int64> feature_start_indices)
     56       : values_(values),
     57         feature_counts_(std::move(feature_counts)),
     58         feature_start_indices_(std::move(feature_start_indices)) {
     59     CHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
     60   }
     61 
     62   int64 FeatureCount(int64 batch) const override {
     63     return feature_counts_[batch];
     64   }
     65 
     66   InternalType Feature(int64 batch, int64 n) const override;
     67 
     68   ~SparseTensorColumn() override {}
     69 
     70  private:
     71   const Tensor& values_;
     72   std::vector<int64> feature_counts_;
     73   std::vector<int64> feature_start_indices_;
     74 };
     75 
     76 // InternalType is int64 only when using HashCrosser.
     77 template <>
     78 int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
     79   const int64 start = feature_start_indices_[batch];
     80   if (DT_STRING == values_.dtype())
     81     return Fingerprint64(values_.vec<string>().data()[start + n]);
     82   return values_.vec<int64>().data()[start + n];
     83 }
     84 
     85 // InternalType is string or StringPiece when using StringCrosser.
     86 template <>
     87 string SparseTensorColumn<string>::Feature(int64 batch, int64 n) const {
     88   const int64 start = feature_start_indices_[batch];
     89   if (DT_STRING == values_.dtype())
     90     return values_.vec<string>().data()[start + n];
     91   return std::to_string(values_.vec<int64>().data()[start + n]);
     92 }
     93 
     94 template <>
     95 StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch,
     96                                                      int64 n) const {
     97   const int64 start = feature_start_indices_[batch];
     98   return values_.vec<string>().data()[start + n];
     99 }
    100 
    101 // A column that is backed by a dense tensor.
    102 template <typename InternalType>
    103 class DenseTensorColumn : public ColumnInterface<InternalType> {
    104  public:
    105   explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {}
    106 
    107   int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }
    108 
    109   InternalType Feature(int64 batch, int64 n) const override;
    110 
    111   ~DenseTensorColumn() override {}
    112 
    113  private:
    114   const Tensor& tensor_;
    115 };
    116 
    117 // InternalType is int64 only when using HashCrosser.
    118 template <>
    119 int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
    120   if (DT_STRING == tensor_.dtype())
    121     return Fingerprint64(tensor_.matrix<string>()(batch, n));
    122   return tensor_.matrix<int64>()(batch, n);
    123 }
    124 
    125 // Internal type is string or StringPiece when using StringCrosser.
    126 template <>
    127 string DenseTensorColumn<string>::Feature(int64 batch, int64 n) const {
    128   if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
    129   return std::to_string(tensor_.matrix<int64>()(batch, n));
    130 }
    131 
    132 template <>
    133 StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch,
    134                                                     int64 n) const {
    135   return tensor_.matrix<string>()(batch, n);
    136 }
    137 
    138 // Updates Output tensors with sparse crosses.
    139 template <typename OutType>
    140 class OutputUpdater {
    141  public:
    142   OutputUpdater(const std::vector<int64>& output_start_indices,
    143                 Tensor* indices_out, Tensor* values_out)
    144       : output_start_indices_(output_start_indices),
    145         indices_out_(indices_out),
    146         values_out_(values_out) {}
    147 
    148   void Update(const int64 batch_index, const int64 cross_count,
    149               const OutType& cross) const {
    150     const int64 output_index = output_start_indices_[batch_index] + cross_count;
    151 
    152     auto indices_matrix = indices_out_->matrix<int64>();
    153     indices_matrix(output_index, 0) = batch_index;
    154     indices_matrix(output_index, 1) = cross_count;
    155 
    156     auto value_vec = values_out_->vec<OutType>();
    157     value_vec(output_index) = cross;
    158   }
    159 
    160  private:
    161   const std::vector<int64>& output_start_indices_;
    162   Tensor* indices_out_;
    163   Tensor* values_out_;
    164 };
    165 
    166 // Generates the sparse crosses as concatenation of strings.
    167 template <typename InternalType>
    168 class StringCrosser {
    169  public:
    170   StringCrosser(const std::vector<
    171                     std::unique_ptr<ColumnInterface<InternalType>>>& columns,
    172                 const int64 num_buckets_unused, const uint64 hash_key_unused)
    173       : columns_(columns) {}
    174 
    175   string Generate(const int64 batch_index,
    176                   const std::vector<int>& permutation) const {
    177     static const auto k_feature_separator = "_X_";
    178 
    179     gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size());
    180     for (size_t i = 0; i < permutation.size(); i++) {
    181       cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i]);
    182     }
    183     // TODO(zakaria): this will copy the string twice, might effect
    184     // performance.
    185     return str_util::Join(cross_vec, k_feature_separator);
    186   }
    187 
    188  private:
    189   const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
    190 };
    191 
    192 // Generates the sparse crosses as nested hash to avoid string manipulations.
    193 class HashCrosser {
    194  public:
    195   HashCrosser(
    196       const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
    197       const int64 num_buckets, const uint64 hash_key_unused)
    198       : columns_(columns), num_buckets_(num_buckets) {}
    199 
    200   int64 Generate(const int64 batch_index,
    201                  const std::vector<int>& permutation) const {
    202     // Seed is chosen based on third_party/tensorflow/core/lib/hash/hash.h
    203     static const int64 kInitialHashSeed = 0xDECAFCAFFE;
    204 
    205     uint64 hashed_output = kInitialHashSeed;
    206     for (size_t i = 0; i < permutation.size(); ++i) {
    207       int64 hash_i = columns_[i]->Feature(batch_index, permutation[i]);
    208       hashed_output = HashCombine(hashed_output, hash_i);
    209     }
    210     if (num_buckets_ > 0) {
    211       return hashed_output % num_buckets_;
    212     } else {
    213       // To prevent negative output we take modulo to max int64.
    214       return hashed_output % std::numeric_limits<int64>::max();
    215     }
    216   }
    217 
    218  private:
    219   static int64 HashCombine(int64 a, int64 b) {
    220     return a ^ (b + 0x9e3779b97f4a7800 + (a << 10) + (a >> 4));
    221   }
    222 
    223   const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
    224   const int64 num_buckets_;
    225 };
    226 
    227 // Generates the sparse crosses as nested hash to avoid string manipulations.
    228 class HashCrosserV2 {
    229  public:
    230   HashCrosserV2(
    231       const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
    232       const int64 num_buckets, const uint64 hash_key)
    233       : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {}
    234 
    235   int64 Generate(const int64 batch_index,
    236                  const std::vector<int>& permutation) const {
    237     // Do the fingerprint concatenation on uint64.
    238     uint64 hashed_output = hash_key_;
    239     for (size_t i = 0; i < permutation.size(); ++i) {
    240       uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i]);
    241       hashed_output = FingerprintCat64(hashed_output, hash_i);
    242     }
    243     // The return value is int64 based on the number of buckets.
    244     if (num_buckets_ > 0) {
    245       return hashed_output % num_buckets_;
    246     } else {
    247       // To prevent negative output we take modulo to max int64.
    248       return hashed_output % std::numeric_limits<int64>::max();
    249     }
    250   }
    251 
    252  private:
    253   const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
    254   const int64 num_buckets_;
    255   const uint64 hash_key_;
    256 };
    257 
    258 // ProductIterator generates cartesian products based on indices.
    259 template <typename InternalType>
    260 class ProductIterator {
    261  public:
    262   explicit ProductIterator(
    263       const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
    264           columns,
    265       int64 batch_index)
    266       : columns_(columns), batch_index_(batch_index) {
    267     next_permutation_.resize(columns_.size(), 0);
    268     // Sets has_next_ to false if any feature column has 0 features.
    269     has_next_ = true;
    270     for (size_t i = 0; i < columns_.size(); i++) {
    271       if (columns_[i]->FeatureCount(batch_index_) == 0) {
    272         has_next_ = false;
    273         break;
    274       }
    275     }
    276   }
    277 
    278   std::vector<int> Next() {
    279     std::vector<int> permutation(next_permutation_);
    280 
    281     // Generates next permutation, if available.
    282     bool carry = true;
    283     for (int i = next_permutation_.size() - 1; i >= 0; i--) {
    284       if (carry) {
    285         next_permutation_[i] = next_permutation_[i] + 1;
    286       }
    287       if (next_permutation_[i] == columns_[i]->FeatureCount(batch_index_)) {
    288         next_permutation_[i] = 0;
    289       } else {
    290         carry = false;
    291         break;
    292       }
    293     }
    294     has_next_ = !carry;
    295     return permutation;
    296   }
    297 
    298   bool HasNext() { return has_next_; }
    299 
    300  private:
    301   bool has_next_;
    302   const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
    303   const int64 batch_index_;
    304   std::vector<int> next_permutation_;
    305 };
    306 
    307 template <bool HASHED_OUTPUT, typename InternalType, bool VERSION_2>
    308 struct CrossTraits;
    309 
    310 template <typename InternalType, bool VERSION_2>
    311 struct CrossTraits<false, InternalType, VERSION_2> {
    312   typedef StringCrosser<InternalType> Crosser;
    313   typedef OutputUpdater<string> Updater;
    314 };
    315 
    316 template <>
    317 struct CrossTraits<true, int64, false> {
    318   typedef HashCrosser Crosser;
    319   typedef OutputUpdater<int64> Updater;
    320 };
    321 
    322 template <>
    323 struct CrossTraits<true, int64, true> {
    324   typedef HashCrosserV2 Crosser;
    325   typedef OutputUpdater<int64> Updater;
    326 };
    327 }  // namespace
    328 
    329 template <bool HASHED_OUTPUT, typename InternalType, bool VERSION_2>
    330 class SparseFeatureCrossOp : public OpKernel {
    331  public:
    332   explicit SparseFeatureCrossOp(OpKernelConstruction* context)
    333       : OpKernel(context) {
    334     OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
    335     if (VERSION_2) {
    336       // Read signed_hash_key_ as int64 since uint64 attributes are not
    337       // supported by REGISTER_OP.
    338       int64 signed_hash_key_;
    339       OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
    340       hash_key_ = static_cast<uint64>(signed_hash_key_);
    341     }
    342   }
    343 
    344   void Compute(OpKernelContext* context) override {
    345     OpInputList indices_list_in;
    346     OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
    347     OpInputList values_list_in;
    348     OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
    349     OpInputList shapes_list_in;
    350     OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
    351     OpInputList dense_list_in;
    352     OP_REQUIRES_OK(context, context->input_list("dense", &dense_list_in));
    353 
    354     ValidateInput(context, indices_list_in, values_list_in, shapes_list_in,
    355                   dense_list_in);
    356 
    357     std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
    358         GenerateColumnsFromInput(indices_list_in, values_list_in,
    359                                  shapes_list_in, dense_list_in);
    360 
    361     typename CrossTraits<HASHED_OUTPUT, InternalType, VERSION_2>::Crosser
    362         crosser(columns, num_buckets_, hash_key_);
    363     Tensor* indices_out;
    364     Tensor* values_out;
    365     Tensor* shape_out;
    366     const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
    367     std::vector<int64> output_start_indices(batch_size);
    368     CreateOutputTensors(columns, batch_size, context, &indices_out, &values_out,
    369                         &shape_out, &output_start_indices);
    370 
    371     typename CrossTraits<HASHED_OUTPUT, InternalType, VERSION_2>::Updater
    372         updater(output_start_indices, indices_out, values_out);
    373     auto do_work = [this, &columns, crosser, updater](int64 begin, int64 end) {
    374       for (int b = begin; b < end; b++) {
    375         ProductIterator<InternalType> product_iterator(columns, b);
    376         int64 cross_count = 0;
    377         while (product_iterator.HasNext()) {
    378           const auto permutation = product_iterator.Next();
    379           updater.Update(b, cross_count, crosser.Generate(b, permutation));
    380           cross_count++;
    381         }
    382       }
    383     };
    384 
    385     auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
    386     // TODO(zakaria): optimize kCostPerUnit
    387     const int kCostPerUnit = 5000 * indices_list_in.size();
    388     Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
    389           kCostPerUnit, do_work);
    390   }
    391 
    392  private:
    393   // Validates input tensors.
    394   void ValidateInput(OpKernelContext* context,
    395                      const OpInputList& indices_list_in,
    396                      const OpInputList& values_list_in,
    397                      const OpInputList& shapes_list_in,
    398                      const OpInputList& dense_list_in) {
    399     const auto size = indices_list_in.size();
    400     // Validates indices_list_in OpInputList.
    401     for (int i = 0; i < size; i++) {
    402       OP_REQUIRES(
    403           context, TensorShapeUtils::IsMatrix(indices_list_in[i].shape()),
    404           errors::InvalidArgument(
    405               "Input indices should be a matrix but received shape ",
    406               indices_list_in[i].shape().DebugString(), " at position ", i));
    407       OP_REQUIRES(
    408           context, indices_list_in[i].shape().dim_size(1) == 2,
    409           errors::InvalidArgument("Expected D2 of index to be 2 got ",
    410                                   indices_list_in[i].shape().dim_size(1),
    411                                   " at position ", i));
    412     }
    413 
    414     // Validates values_list_in OpInputList.
    415     OP_REQUIRES(
    416         context, values_list_in.size() == size,
    417         errors::InvalidArgument("Expected ", size, " input values, got ",
    418                                 values_list_in.size()));
    419     for (int i = 0; i < size; i++) {
    420       OP_REQUIRES(
    421           context, TensorShapeUtils::IsVector(values_list_in[i].shape()),
    422           errors::InvalidArgument(
    423               "Input values should be a std::vector but received shape ",
    424               values_list_in[i].shape().DebugString(), " at position ", i));
    425       OP_REQUIRES(
    426           context,
    427           indices_list_in[i].shape().dim_size(0) ==
    428               values_list_in[i].shape().dim_size(0),
    429           errors::InvalidArgument(
    430               "Expected size of values to be ",
    431               indices_list_in[i].shape().dim_size(0), " got ",
    432               values_list_in[i].shape().dim_size(0), " at position ", i));
    433     }
    434 
    435     // Validates shapes_list_in OpInputList
    436     OP_REQUIRES(
    437         context, shapes_list_in.size() == size,
    438         errors::InvalidArgument("Expected ", size, " input shapes, got ",
    439                                 shapes_list_in.size()));
    440     const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
    441     for (int i = 0; i < size; i++) {
    442       OP_REQUIRES(
    443           context, TensorShapeUtils::IsVector(shapes_list_in[i].shape()),
    444           errors::InvalidArgument(
    445               "Input shapes should be a std::vector but received shape ",
    446               shapes_list_in[i].shape().DebugString(), " at position ", i));
    447 
    448       OP_REQUIRES(
    449           context, shapes_list_in[i].vec<int64>().size() == 2,
    450           errors::InvalidArgument("shape should imply a 2D tensor, but got ",
    451                                   shapes_list_in[i].shape().DebugString(),
    452                                   " at position ", i));
    453       OP_REQUIRES(context, shapes_list_in[i].vec<int64>()(0) == batch_size,
    454                   errors::InvalidArgument(
    455                       "Expected batch size ", batch_size, " got ",
    456                       shapes_list_in[i].vec<int64>()(0), " at position ", i));
    457     }
    458 
    459     // Validates dense_list_in OpInputList
    460     for (int i = 0; i < dense_list_in.size(); ++i) {
    461       OP_REQUIRES(
    462           context, TensorShapeUtils::IsMatrix(dense_list_in[i].shape()),
    463           errors::InvalidArgument(
    464               "Dense inputs should be a matrix but received shape ",
    465               indices_list_in[i].shape().DebugString(), " at position ", i));
    466       OP_REQUIRES(context, dense_list_in[i].dim_size(0) == batch_size,
    467                   errors::InvalidArgument("Expected batch size ", batch_size,
    468                                           " got ", dense_list_in[i].dim_size(0),
    469                                           " at dense tensor ", i));
    470     }
    471   }
    472 
    473   // Calculate the batch size from either the shapes input or the dense input.
    474   int64 CalculateBatchSize(const OpInputList& shapes_list_in,
    475                            const OpInputList& dense_list_in) {
    476     if (shapes_list_in.size() > 0) {
    477       return shapes_list_in[0].vec<int64>()(0);
    478     }
    479 
    480     if (dense_list_in.size() > 0) {
    481       return dense_list_in[0].dim_size(0);
    482     }
    483 
    484     return 0;
    485   }
    486 
    487   // Generate the columns given the sparse and dense inputs.
    488   std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
    489   GenerateColumnsFromInput(const OpInputList& indices_list_in,
    490                            const OpInputList& values_list_in,
    491                            const OpInputList& shapes_list_in,
    492                            const OpInputList& dense_list_in) {
    493     std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
    494     const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
    495     const int64 number_of_columns = shapes_list_in.size();
    496 
    497     std::vector<std::vector<int64>> feature_counts(number_of_columns,
    498                                                    std::vector<int64>());
    499     std::vector<std::vector<int64>> feature_start_indices(number_of_columns,
    500                                                           std::vector<int64>());
    501 
    502     ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
    503                        &feature_start_indices);
    504 
    505     columns.reserve(values_list_in.size());
    506     for (int i = 0; i < values_list_in.size(); ++i) {
    507       columns.emplace_back(new SparseTensorColumn<InternalType>(
    508           values_list_in[i], std::move(feature_counts[i]),
    509           std::move(feature_start_indices[i])));
    510     }
    511     for (int i = 0; i < dense_list_in.size(); ++i) {
    512       columns.emplace_back(
    513           new DenseTensorColumn<InternalType>(dense_list_in[i]));
    514     }
    515 
    516     return columns;
    517   }
    518 
    519   // Extracts data about the features and populates feature data.
    520   void ExtractFeatureData(
    521       const OpInputList& indices_list_in, int64 batch_size,
    522       std::vector<std::vector<int64>>* feature_counts,
    523       std::vector<std::vector<int64>>* feature_start_indices) {
    524     gtl::InlinedVector<int64, 8> current_row(indices_list_in.size(), 0);
    525     for (int b = 0; b < batch_size; b++) {
    526       for (int i = 0; i < indices_list_in.size(); i++) {
    527         const auto indices = indices_list_in[i].matrix<int64>();
    528         int64 feature_count = 0;
    529         int64 start_index = current_row[i];
    530         // Loops until we reach next batch index for current feature column.
    531         while (current_row[i] < indices_list_in[i].dim_size(0) &&
    532                indices(current_row[i], 0) == b) {
    533           feature_count++;
    534           current_row[i]++;
    535         }
    536         (*feature_counts)[i].push_back(feature_count);
    537         (*feature_start_indices)[i].push_back(start_index);
    538       }
    539     }
    540   }
    541 
    542   // Allocates output tensors with proper size and sets the shape tensor of
    543   // the output SparseTensor.
    544   // It also output_start_indices which contains the start indices for each
    545   // input in the output SparseTensor.
    546   void CreateOutputTensors(
    547       const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
    548           columns,
    549       int64 batch_size, OpKernelContext* context, Tensor** indices_out,
    550       Tensor** values_out, Tensor** shape_out,
    551       std::vector<int64>* output_start_indices) {
    552     // Calculates dimensions for output tensors.
    553     int64 cross_count_total = 0;
    554     int64 max_cross_count = 0;
    555     for (int64 b = 0; b < batch_size; b++) {
    556       // For each input, sets starting indices in output SparseTensor
    557       (*output_start_indices)[b] = cross_count_total;
    558       const auto cross_count = CrossCountByBatchIndex(columns, b);
    559       max_cross_count = std::max(max_cross_count, cross_count);
    560       cross_count_total += cross_count;
    561     }
    562 
    563     // Allocates tensors.
    564     OP_REQUIRES_OK(context,
    565                    context->allocate_output(
    566                        0, TensorShape({cross_count_total, 2}), indices_out));
    567     OP_REQUIRES_OK(context,
    568                    context->allocate_output(1, TensorShape({cross_count_total}),
    569                                             values_out));
    570     OP_REQUIRES_OK(context,
    571                    context->allocate_output(2, TensorShape({2}), shape_out));
    572 
    573     // Sets shape.
    574     auto shape_vec = (*shape_out)->vec<int64>();
    575     shape_vec(0) = batch_size;
    576     shape_vec(1) = max_cross_count;
    577   }
    578 
    579   // Returns number of crosses for a given batch_index
    580   int64 CrossCountByBatchIndex(
    581       const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
    582           columns,
    583       int batch_index) {
    584     int64 cross_count = 1;
    585     for (size_t i = 0; i < columns.size(); i++) {
    586       const auto feature_count = columns[i]->FeatureCount(batch_index);
    587       // If one column is missing any feature, there won't be any cross.
    588       if (feature_count == 0) {
    589         return 0;
    590       }
    591       cross_count *= feature_count;
    592     }
    593     return cross_count;
    594   }
    595   int64 num_buckets_;
    596   uint64 hash_key_;
    597 };
    598 
    599 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
    600                             .Device(DEVICE_CPU)
    601                             .TypeConstraint<string>("out_type")
    602                             .TypeConstraint<string>("internal_type"),
    603                         SparseFeatureCrossOp<false, StringPiece, false>);
    604 
    605 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
    606                             .Device(DEVICE_CPU)
    607                             .TypeConstraint<string>("out_type")
    608                             .TypeConstraint<int64>("internal_type"),
    609                         SparseFeatureCrossOp<false, string, false>);
    610 
    611 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
    612                             .Device(DEVICE_CPU)
    613                             .TypeConstraint<int64>("out_type")
    614                             .TypeConstraint<string>("internal_type"),
    615                         SparseFeatureCrossOp<true, int64, false>);
    616 
    617 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
    618                             .Device(DEVICE_CPU)
    619                             .TypeConstraint<int64>("out_type")
    620                             .TypeConstraint<int64>("internal_type"),
    621                         SparseFeatureCrossOp<true, int64, false>);
    622 
    623 // The following builders enable FingerprintCat64 concatenation for the
    624 // crosses features.
    625 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
    626                             .Device(DEVICE_CPU)
    627                             .TypeConstraint<string>("out_type")
    628                             .TypeConstraint<string>("internal_type"),
    629                         SparseFeatureCrossOp<false, StringPiece, true>);
    630 
    631 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
    632                             .Device(DEVICE_CPU)
    633                             .TypeConstraint<string>("out_type")
    634                             .TypeConstraint<int64>("internal_type"),
    635                         SparseFeatureCrossOp<false, string, true>);
    636 
    637 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
    638                             .Device(DEVICE_CPU)
    639                             .TypeConstraint<int64>("out_type")
    640                             .TypeConstraint<string>("internal_type"),
    641                         SparseFeatureCrossOp<true, int64, true>);
    642 
    643 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
    644                             .Device(DEVICE_CPU)
    645                             .TypeConstraint<int64>("out_type")
    646                             .TypeConstraint<int64>("internal_type"),
    647                         SparseFeatureCrossOp<true, int64, true>);
    648 
    649 }  // namespace tensorflow
    650