/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Contains the op that generates sparse crosses.
#include <assert.h>
#include <limits>
#include <string>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

namespace {
// An interface that represents a column with batches.
template <typename InternalType>
class ColumnInterface {
 public:
  // Returns the number of features in the specified batch.
  virtual int64 FeatureCount(int64 batch) const = 0;

  // Returns the nth feature from the specified batch; depending on
  // InternalType this is either the feature value or its fingerprint.
  virtual InternalType Feature(int64 batch, int64 n) const = 0;

  virtual ~ColumnInterface() {}
};

// A column that is backed by a sparse tensor.
template <typename InternalType>
class SparseTensorColumn : public ColumnInterface<InternalType> {
 public:
  SparseTensorColumn(const Tensor& values, std::vector<int64> feature_counts,
                     std::vector<int64> feature_start_indices)
      : values_(values),
        feature_counts_(std::move(feature_counts)),
        feature_start_indices_(std::move(feature_start_indices)) {
    CHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
  }

  int64 FeatureCount(int64 batch) const override {
    return feature_counts_[batch];
  }

  InternalType Feature(int64 batch, int64 n) const override;

  ~SparseTensorColumn() override {}

 private:
  const Tensor& values_;
  std::vector<int64> feature_counts_;
  std::vector<int64> feature_start_indices_;
};
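// Note: values_ holds the flattened values of the underlying SparseTensor.
// feature_start_indices_[batch] is the offset of the first value belonging to
// `batch`, and feature_counts_[batch] is how many values follow it. For
// example, values ["a", "b", "c"] split as batch 0 -> {"a"} and
// batch 1 -> {"b", "c"} give feature_counts_ = {1, 2} and
// feature_start_indices_ = {0, 1}.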

// InternalType is int64 only when using HashCrosser.
template <>
int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
  const int64 start = feature_start_indices_[batch];
  if (DT_STRING == values_.dtype())
    return Fingerprint64(values_.vec<string>().data()[start + n]);
  return values_.vec<int64>().data()[start + n];
}

// InternalType is string or StringPiece when using StringCrosser.
template <>
string SparseTensorColumn<string>::Feature(int64 batch, int64 n) const {
  const int64 start = feature_start_indices_[batch];
  if (DT_STRING == values_.dtype())
    return values_.vec<string>().data()[start + n];
  return std::to_string(values_.vec<int64>().data()[start + n]);
}

template <>
StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch,
                                                     int64 n) const {
  const int64 start = feature_start_indices_[batch];
  return values_.vec<string>().data()[start + n];
}

// A column that is backed by a dense tensor.
template <typename InternalType>
class DenseTensorColumn : public ColumnInterface<InternalType> {
 public:
  explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {}

  int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }

  InternalType Feature(int64 batch, int64 n) const override;

  ~DenseTensorColumn() override {}

 private:
  const Tensor& tensor_;
};

// InternalType is int64 only when using HashCrosser.
template <>
int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
  if (DT_STRING == tensor_.dtype())
    return Fingerprint64(tensor_.matrix<string>()(batch, n));
  return tensor_.matrix<int64>()(batch, n);
}

// InternalType is string or StringPiece when using StringCrosser.
template <>
string DenseTensorColumn<string>::Feature(int64 batch, int64 n) const {
  if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
  return std::to_string(tensor_.matrix<int64>()(batch, n));
}

template <>
StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch,
                                                    int64 n) const {
  return tensor_.matrix<string>()(batch, n);
}

// Updates output tensors with sparse crosses.
template <typename OutType>
class OutputUpdater {
 public:
  OutputUpdater(const std::vector<int64>& output_start_indices,
                Tensor* indices_out, Tensor* values_out)
      : output_start_indices_(output_start_indices),
        indices_out_(indices_out),
        values_out_(values_out) {}

  void Update(const int64 batch_index, const int64 cross_count,
              const OutType& cross) const {
    const int64 output_index = output_start_indices_[batch_index] + cross_count;

    auto indices_matrix = indices_out_->matrix<int64>();
    indices_matrix(output_index, 0) = batch_index;
    indices_matrix(output_index, 1) = cross_count;

    auto value_vec = values_out_->vec<OutType>();
    value_vec(output_index) = cross;
  }

 private:
  const std::vector<int64>& output_start_indices_;
  Tensor* indices_out_;
  Tensor* values_out_;
};
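// For example, if batch 0 produces 3 crosses and batch 1 produces 2, then
// output_start_indices_ is {0, 3} and the cross_count-th cross of batch b is
// written to row output_start_indices_[b] + cross_count of the output.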

// Generates the sparse crosses as a concatenation of strings.
template <typename InternalType>
class StringCrosser {
 public:
  StringCrosser(const std::vector<
                    std::unique_ptr<ColumnInterface<InternalType>>>& columns,
                const int64 num_buckets_unused, const uint64 hash_key_unused)
      : columns_(columns) {}

  string Generate(const int64 batch_index,
                  const std::vector<int>& permutation) const {
    static const auto k_feature_separator = "_X_";

    gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size());
    for (int i = 0; i < permutation.size(); i++) {
      cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i]);
    }
    // TODO(zakaria): this will copy the string twice, might affect
    // performance.
    return str_util::Join(cross_vec, k_feature_separator);
  }

 private:
  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
};
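// For example, crossing a column with features {"a", "b"} against a column
// with the single feature {"c"} in the same batch yields the crosses "a_X_c"
// and "b_X_c".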

// Generates the sparse crosses as a nested hash to avoid string manipulation.
class HashCrosser {
 public:
  HashCrosser(
      const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
      const int64 num_buckets, const uint64 hash_key)
      : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {}

  int64 Generate(const int64 batch_index,
                 const std::vector<int>& permutation) const {
    // Do the fingerprint concatenation on uint64.
    uint64 hashed_output = hash_key_;
    for (size_t i = 0; i < permutation.size(); ++i) {
      uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i]);
      hashed_output = FingerprintCat64(hashed_output, hash_i);
    }
    // The return value is int64 based on the number of buckets.
    if (num_buckets_ > 0) {
      return hashed_output % num_buckets_;
    } else {
      // To prevent a negative output we take the result modulo the max int64.
      return hashed_output % std::numeric_limits<int64>::max();
    }
  }

 private:
  const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
  const int64 num_buckets_;
  const uint64 hash_key_;
};
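// For example, with two columns the output for one permutation is effectively
// FingerprintCat64(FingerprintCat64(hash_key_, f0), f1) reduced modulo
// num_buckets_ (or modulo the max int64 when num_buckets_ <= 0), where f0 and
// f1 are the already-fingerprinted features of the two columns.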

// ProductIterator generates cartesian products based on indices.
template <typename InternalType>
class ProductIterator {
 public:
  explicit ProductIterator(
      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
          columns,
      int64 batch_index)
      : columns_(columns), batch_index_(batch_index) {
    next_permutation_.resize(columns_.size(), 0);
    // Sets has_next_ to false if any feature column has 0 features.
    has_next_ = true;
    for (int i = 0; i < columns_.size(); i++) {
      if (columns_[i]->FeatureCount(batch_index_) == 0) {
        has_next_ = false;
        break;
      }
    }
  }

  std::vector<int> Next() {
    std::vector<int> permutation(next_permutation_);

    // Generates next permutation, if available.
    bool carry = true;
    for (int i = next_permutation_.size() - 1; i >= 0; i--) {
      if (carry) {
        next_permutation_[i] = next_permutation_[i] + 1;
      }
      if (next_permutation_[i] == columns_[i]->FeatureCount(batch_index_)) {
        next_permutation_[i] = 0;
      } else {
        carry = false;
        break;
      }
    }
    has_next_ = !carry;
    return permutation;
  }

  bool HasNext() { return has_next_; }

 private:
  bool has_next_;
  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
  const int64 batch_index_;
  std::vector<int> next_permutation_;
};
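// Iteration proceeds in odometer order with the last column varying fastest.
// For example, for two columns with 2 and 3 features respectively, the
// permutations are (0,0), (0,1), (0,2), (1,0), (1,1), (1,2).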

template <bool HASHED_OUTPUT, typename InternalType>
struct CrossTraits;

template <typename InternalType>
struct CrossTraits<false, InternalType> {
  typedef StringCrosser<InternalType> Crosser;
  typedef OutputUpdater<string> Updater;
};

template <>
struct CrossTraits<true, int64> {
  typedef HashCrosser Crosser;
  typedef OutputUpdater<int64> Updater;
};
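// CrossTraits picks the crosser/updater pair: string output pairs
// StringCrosser with OutputUpdater<string>, while hashed output pairs
// HashCrosser with OutputUpdater<int64>.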
}  // namespace

template <bool HASHED_OUTPUT, typename InternalType>
class SparseCrossOp : public OpKernel {
 public:
  explicit SparseCrossOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
    // Read signed_hash_key_ as int64 since uint64 attributes are not
    // supported by REGISTER_OP.
    int64 signed_hash_key_;
    OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
    hash_key_ = static_cast<uint64>(signed_hash_key_);
  }

  void Compute(OpKernelContext* context) override {
    OpInputList indices_list_in;
    OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
    OpInputList values_list_in;
    OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
    OpInputList shapes_list_in;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
    OpInputList dense_list_in;
    OP_REQUIRES_OK(context,
                   context->input_list("dense_inputs", &dense_list_in));

    ValidateInput(context, indices_list_in, values_list_in, shapes_list_in,
                  dense_list_in);

    std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
        GenerateColumnsFromInput(indices_list_in, values_list_in,
                                 shapes_list_in, dense_list_in);

    typename CrossTraits<HASHED_OUTPUT, InternalType>::Crosser crosser(
        columns, num_buckets_, hash_key_);
    Tensor* indices_out;
    Tensor* values_out;
    Tensor* shape_out;
    const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
    std::vector<int64> output_start_indices(batch_size);
    CreateOutputTensors(columns, batch_size, context, &indices_out, &values_out,
                        &shape_out, &output_start_indices);

    typename CrossTraits<HASHED_OUTPUT, InternalType>::Updater updater(
        output_start_indices, indices_out, values_out);
    auto do_work = [this, &columns, crosser, updater](int64 begin, int64 end) {
      for (int b = begin; b < end; b++) {
        ProductIterator<InternalType> product_iterator(columns, b);
        int64 cross_count = 0;
        while (product_iterator.HasNext()) {
          const auto permutation = product_iterator.Next();
          updater.Update(b, cross_count, crosser.Generate(b, permutation));
          cross_count++;
        }
      }
    };

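    // Shard the per-batch work across the CPU worker threads; Shard() invokes
    // do_work on disjoint sub-ranges of [0, batch_size).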
    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
    // TODO(zakaria): optimize kCostPerUnit
    const int kCostPerUnit = 5000 * indices_list_in.size();
    Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
          kCostPerUnit, do_work);
  }

 private:
  // Validates input tensors.
  void ValidateInput(OpKernelContext* context,
                     const OpInputList& indices_list_in,
                     const OpInputList& values_list_in,
                     const OpInputList& shapes_list_in,
                     const OpInputList& dense_list_in) {
    const auto size = indices_list_in.size();
    // Validates indices_list_in OpInputList.
    for (int i = 0; i < size; i++) {
      OP_REQUIRES(
          context, TensorShapeUtils::IsMatrix(indices_list_in[i].shape()),
          errors::InvalidArgument(
              "Input indices should be a matrix but received shape ",
              indices_list_in[i].shape().DebugString(), " at position ", i));
      OP_REQUIRES(
          context, indices_list_in[i].shape().dim_size(1) == 2,
          errors::InvalidArgument("Expected D2 of index to be 2 got ",
                                  indices_list_in[i].shape().dim_size(1),
                                  " at position ", i));
    }

    // Validates values_list_in OpInputList.
    OP_REQUIRES(
        context, values_list_in.size() == size,
        errors::InvalidArgument("Expected ", size, " input values, got ",
                                values_list_in.size()));
    for (int i = 0; i < size; i++) {
      OP_REQUIRES(
          context, TensorShapeUtils::IsVector(values_list_in[i].shape()),
          errors::InvalidArgument(
              "Input values should be a vector but received shape ",
              values_list_in[i].shape().DebugString(), " at position ", i));
      OP_REQUIRES(
          context,
          indices_list_in[i].shape().dim_size(0) ==
              values_list_in[i].shape().dim_size(0),
          errors::InvalidArgument(
              "Expected size of values to be ",
              indices_list_in[i].shape().dim_size(0), " got ",
              values_list_in[i].shape().dim_size(0), " at position ", i));
    }

    // Validates shapes_list_in OpInputList.
    OP_REQUIRES(
        context, shapes_list_in.size() == size,
        errors::InvalidArgument("Expected ", size, " input shapes, got ",
                                shapes_list_in.size()));
    const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
    for (int i = 0; i < size; i++) {
      OP_REQUIRES(
          context, TensorShapeUtils::IsVector(shapes_list_in[i].shape()),
          errors::InvalidArgument(
              "Input shapes should be a vector but received shape ",
              shapes_list_in[i].shape().DebugString(), " at position ", i));

      OP_REQUIRES(
          context, shapes_list_in[i].vec<int64>().size() == 2,
          errors::InvalidArgument("shape should imply a 2D tensor, but got ",
                                  shapes_list_in[i].shape().DebugString(),
                                  " at position ", i));
      OP_REQUIRES(context, shapes_list_in[i].vec<int64>()(0) == batch_size,
                  errors::InvalidArgument(
                      "Expected batch size ", batch_size, " got ",
                      shapes_list_in[i].vec<int64>()(0), " at position ", i));
    }

    // Validates dense_list_in OpInputList.
    for (int i = 0; i < dense_list_in.size(); ++i) {
      OP_REQUIRES(
          context, TensorShapeUtils::IsMatrix(dense_list_in[i].shape()),
          errors::InvalidArgument(
              "Dense inputs should be a matrix but received shape ",
              dense_list_in[i].shape().DebugString(), " at position ", i));
      OP_REQUIRES(context, dense_list_in[i].dim_size(0) == batch_size,
                  errors::InvalidArgument("Expected batch size ", batch_size,
                                          " got ", dense_list_in[i].dim_size(0),
                                          " at dense tensor ", i));
    }
  }

  // Calculate the batch size from either the shapes input or the dense input.
  int64 CalculateBatchSize(const OpInputList& shapes_list_in,
                           const OpInputList& dense_list_in) {
    if (shapes_list_in.size() > 0) {
      return shapes_list_in[0].vec<int64>()(0);
    }

    if (dense_list_in.size() > 0) {
      return dense_list_in[0].dim_size(0);
    }

    return 0;
  }

  // Generate the columns given the sparse and dense inputs.
  std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
  GenerateColumnsFromInput(const OpInputList& indices_list_in,
                           const OpInputList& values_list_in,
                           const OpInputList& shapes_list_in,
                           const OpInputList& dense_list_in) {
    std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
    const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
    const int64 number_of_columns = shapes_list_in.size();

    std::vector<std::vector<int64>> feature_counts(number_of_columns,
                                                   std::vector<int64>());
    std::vector<std::vector<int64>> feature_start_indices(number_of_columns,
                                                          std::vector<int64>());

    ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
                       &feature_start_indices);

    columns.reserve(values_list_in.size());
    for (int i = 0; i < values_list_in.size(); ++i) {
      columns.emplace_back(new SparseTensorColumn<InternalType>(
          values_list_in[i], std::move(feature_counts[i]),
          std::move(feature_start_indices[i])));
    }
    for (int i = 0; i < dense_list_in.size(); ++i) {
      columns.emplace_back(
          new DenseTensorColumn<InternalType>(dense_list_in[i]));
    }

    return columns;
  }

  // Extracts data about the features and populates feature data.
  void ExtractFeatureData(
      const OpInputList& indices_list_in, int64 batch_size,
      std::vector<std::vector<int64>>* feature_counts,
      std::vector<std::vector<int64>>* feature_start_indices) {
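    // Note: this walk assumes each indices tensor is ordered by batch (its
    // first column is non-decreasing), so that all rows for batch b are
    // consumed before moving on to batch b + 1.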
    gtl::InlinedVector<int64, 8> current_row(indices_list_in.size(), 0);
    for (int b = 0; b < batch_size; b++) {
      for (int i = 0; i < indices_list_in.size(); i++) {
        const auto indices = indices_list_in[i].matrix<int64>();
        int64 feature_count = 0;
        int64 start_index = current_row[i];
        // Loops until we reach the next batch index for the current feature
        // column.
        while (current_row[i] < indices_list_in[i].dim_size(0) &&
               indices(current_row[i], 0) == b) {
          feature_count++;
          current_row[i]++;
        }
        (*feature_counts)[i].push_back(feature_count);
        (*feature_start_indices)[i].push_back(start_index);
      }
    }
  }

  // Allocates output tensors with proper size and sets the shape tensor of
  // the output SparseTensor.
  // It also fills output_start_indices, which contains the start index in the
  // output SparseTensor for each batch.
  void CreateOutputTensors(
      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
          columns,
      int64 batch_size, OpKernelContext* context, Tensor** indices_out,
      Tensor** values_out, Tensor** shape_out,
      std::vector<int64>* output_start_indices) {
    // Calculates dimensions for output tensors.
    int64 cross_count_total = 0;
    int64 max_cross_count = 0;
    for (int64 b = 0; b < batch_size; b++) {
      // For each batch, sets the starting index in the output SparseTensor.
      (*output_start_indices)[b] = cross_count_total;
      const auto cross_count = CrossCountByBatchIndex(columns, b);
      max_cross_count = std::max(max_cross_count, cross_count);
      cross_count_total += cross_count;
    }

    // Allocates tensors.
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({cross_count_total, 2}), indices_out));
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, TensorShape({cross_count_total}),
                                            values_out));
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, TensorShape({2}), shape_out));

    // Sets shape.
    auto shape_vec = (*shape_out)->vec<int64>();
    shape_vec(0) = batch_size;
    shape_vec(1) = max_cross_count;
  }
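  // Note: the dense shape set above is [batch_size, max_cross_count]; batches
  // with fewer crosses simply contribute fewer rows to the indices and values
  // outputs.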

  // Returns the number of crosses for a given batch_index.
  int64 CrossCountByBatchIndex(
      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
          columns,
      int batch_index) {
    int64 cross_count = 1;
    for (int i = 0; i < columns.size(); i++) {
      const auto feature_count = columns[i]->FeatureCount(batch_index);
      // If one column has no features, there are no crosses for this batch.
      if (feature_count == 0) {
        return 0;
      }
      cross_count *= feature_count;
    }
    return cross_count;
  }
  int64 num_buckets_;
  uint64 hash_key_;
};

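// For string output (out_type = string) the op crosses on StringPiece when all
// inputs are already strings, and on owned strings when int64 inputs have to
// be converted via std::to_string. For hashed output (out_type = int64) both
// registrations use the int64/HashCrosser path.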
REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<string>("out_type")
                            .TypeConstraint<string>("internal_type"),
                        SparseCrossOp<false, StringPiece>);

REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<string>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseCrossOp<false, string>);

REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<string>("internal_type"),
                        SparseCrossOp<true, int64>);

REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseCrossOp<true, int64>);

}  // namespace tensorflow