/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

using sparse::SparseTensor;

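// A SparseTensorsMap is a process-wide resource that stores SparseTensors
// keyed by int64 handles.  Handles are assigned from a monotonically
// increasing counter, and an entry is removed from the map when it is
// retrieved, so each handle can be consumed at most once.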
class SparseTensorsMap : public ResourceBase {
 public:
  explicit SparseTensorsMap(const string& name) : name_(name), counter_(0) {}

  string DebugString() override { return "A SparseTensorsMap"; }

  struct PersistentSparseTensor {
    PersistentTensor indices;
    PersistentTensor values;
    gtl::InlinedVector<int64, 8> shape;
  };

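  // Copies the indices and values of `sp` into persistent tensors, records
  // its shape, and returns a fresh handle in `*handle` for later retrieval
  // via RetrieveAndClearSparseTensors().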
  Status AddSparseTensor(OpKernelContext* ctx, const SparseTensor& sp,
                         int64* handle) {
    PersistentTensor persistent_ix;
    Tensor* ix;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        sp.indices().dtype(), sp.indices().shape(), &persistent_ix, &ix));
    *ix = sp.indices();

    PersistentTensor persistent_values;
    Tensor* values;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(sp.values().dtype(),
                                                sp.values().shape(),
                                                &persistent_values, &values));
    *values = sp.values();
    {
      mutex_lock l(mu_);
      int64 unique_st_handle = counter_++;  // increment is guarded on purpose
      sp_tensors_[unique_st_handle] = PersistentSparseTensor{
          persistent_ix, persistent_values,
          gtl::InlinedVector<int64, 8>(sp.shape().begin(), sp.shape().end())};
      *handle = unique_st_handle;
    }
    return Status::OK();
  }

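  // Reconstructs the SparseTensor behind each entry of `handles` into
  // `sparse_tensors` and erases the entries from the map, failing if any
  // handle is unknown (e.g. already retrieved).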
  Status RetrieveAndClearSparseTensors(
      OpKernelContext* ctx, const TTypes<int64>::ConstVec& handles,
      std::vector<SparseTensor>* sparse_tensors) {
    sparse_tensors->clear();
    sparse_tensors->reserve(handles.size());
    {
      mutex_lock l(mu_);
      for (size_t i = 0; i < handles.size(); ++i) {
        const int64 handle = handles(i);
        auto sp_iter = sp_tensors_.find(handle);
        if (sp_iter == sp_tensors_.end()) {
          return errors::InvalidArgument(
              "Unable to find SparseTensor: ", handle, " in map: ", name_);
        }
        const Tensor* ix = sp_iter->second.indices.AccessTensor(ctx);
        const Tensor* values = sp_iter->second.values.AccessTensor(ctx);
        const auto& shape = sp_iter->second.shape;
        sparse_tensors->emplace_back(*ix, *values, shape);

        sp_tensors_.erase(sp_iter);
      }
    }

    return Status::OK();
  }

 protected:
  ~SparseTensorsMap() override {}

 private:
  string name_;

  mutex mu_;
  int64 counter_ GUARDED_BY(mu_);
  std::unordered_map<int64, PersistentSparseTensor> sp_tensors_ GUARDED_BY(mu_);
};

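// Common base class for the kernels below.  GetMap() lazily looks up (or,
// when writing, creates) the SparseTensorsMap resource identified by this
// op's container/shared_name attrs, caches the pointer, and holds a
// reference to it for the lifetime of the kernel.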
class SparseTensorAccessingOp : public OpKernel {
 public:
  typedef std::function<Status(SparseTensorsMap**)> CreatorCallback;

  explicit SparseTensorAccessingOp(OpKernelConstruction* context)
      : OpKernel(context), sparse_tensors_map_(nullptr) {}

 protected:
  ~SparseTensorAccessingOp() override {
    if (sparse_tensors_map_) sparse_tensors_map_->Unref();
  }

  Status GetMap(OpKernelContext* ctx, bool is_writing,
                SparseTensorsMap** sparse_tensors_map) {
    mutex_lock l(mu_);

    if (sparse_tensors_map_) {
      *sparse_tensors_map = sparse_tensors_map_;
      return Status::OK();
    }

    TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def(),
                                   is_writing /* use_node_name_as_default */));

    CreatorCallback sparse_tensors_map_creator = [this](SparseTensorsMap** c) {
      SparseTensorsMap* map = new SparseTensorsMap(cinfo_.name());
      *c = map;
      return Status::OK();
    };

    TF_RETURN_IF_ERROR(
        cinfo_.resource_manager()->LookupOrCreate<SparseTensorsMap>(
            cinfo_.container(), cinfo_.name(), &sparse_tensors_map_,
            sparse_tensors_map_creator));

    *sparse_tensors_map = sparse_tensors_map_;
    return Status::OK();
  }

 private:
  ContainerInfo cinfo_;

  mutex mu_;
  SparseTensorsMap* sparse_tensors_map_ PT_GUARDED_BY(mu_);
};

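// AddSparseToTensorsMapOp stores a single SparseTensor, given as separate
// indices/values/shape inputs, in the shared map and outputs its scalar
// int64 handle.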
class AddSparseToTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit AddSparseToTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;
    SparseTensorsMap* map;

    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));

    TensorShape input_shape_object;
    OP_REQUIRES_OK(context,
                   TensorShapeUtils::MakeShape(input_shape->vec<int64>().data(),
                                               input_shape->NumElements(),
                                               &input_shape_object));
    SparseTensor st(*input_indices, *input_values, input_shape_object);
    int64 handle;
    OP_REQUIRES_OK(context, map->AddSparseTensor(context, st, &handle));

    Tensor sparse_handle(DT_INT64, TensorShape({}));
    auto sparse_handle_t = sparse_handle.scalar<int64>();

    sparse_handle_t() = handle;

    context->set_output(0, sparse_handle);
  }
};

REGISTER_KERNEL_BUILDER(Name("AddSparseToTensorsMap").Device(DEVICE_CPU),
                        AddSparseToTensorsMapOp);

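// AddManySparseToTensorsMapOp splits a rank-R SparseTensor along its first
// (minibatch) dimension into N SparseTensors of rank R - 1, stores each of
// them in the shared map, and outputs the N handles as an int64 vector.
// Minibatch entries with no values are stored as empty SparseTensors.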
template <typename T>
class AddManySparseToTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit AddManySparseToTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* input_indices;
    const Tensor* input_values;
    const Tensor* input_shape;
    SparseTensorsMap* map;

    OP_REQUIRES_OK(context, context->input("sparse_indices", &input_indices));
    OP_REQUIRES_OK(context, context->input("sparse_values", &input_values));
    OP_REQUIRES_OK(context, context->input("sparse_shape", &input_shape));
    OP_REQUIRES_OK(context, GetMap(context, true /* is_writing */, &map));

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices->shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_indices->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values->shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_values->shape().DebugString()));

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape->shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape->shape().DebugString()));

    int rank = input_shape->NumElements();

    OP_REQUIRES(
        context, rank > 1,
        errors::InvalidArgument(
            "Rank of input SparseTensor should be > 1, but saw rank: ", rank));

    TensorShape tensor_input_shape(input_shape->vec<int64>());
    gtl::InlinedVector<int64, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);
    SparseTensor input_st(*input_indices, *input_values, tensor_input_shape,
                          std_order);

    auto input_shape_t = input_shape->vec<int64>();
    const int64 N = input_shape_t(0);

    Tensor sparse_handles(DT_INT64, TensorShape({N}));
    auto sparse_handles_t = sparse_handles.vec<int64>();

    OP_REQUIRES_OK(context, input_st.IndicesValid());

    // All minibatch entries share the same output shape: the input shape
    // with the leading minibatch dimension dropped.  Compute it once.
    TensorShape output_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_shape_t.data() + 1,
                                input_shape->NumElements() - 1, &output_shape));

    // Get groups by minibatch dimension.
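    // For example, indices [[0, 1], [0, 3], [2, 0]] yield two groups:
    // b = 0 (two entries) and b = 2 (one entry).  b = 1 produces no group
    // and is backfilled with an empty SparseTensor further below.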
    std::unordered_set<int64> visited;
    sparse::GroupIterable minibatch = input_st.group({0});
    for (const auto& subset : minibatch) {
      const int64 b = subset.group()[0];
      visited.insert(b);
      OP_REQUIRES(
          context, b > -1 && b < N,
          errors::InvalidArgument(
              "Received unexpected column 0 value in input SparseTensor: ", b,
              " < 0 or >= N (= ", N, ")"));

      const auto indices = subset.indices();
      const auto values = subset.values<T>();
      const int64 num_entries = values.size();

      Tensor output_indices = Tensor(DT_INT64, {num_entries, rank - 1});
      Tensor output_values = Tensor(DataTypeToEnum<T>::value, {num_entries});

      auto output_indices_t = output_indices.matrix<int64>();
      auto output_values_t = output_values.vec<T>();

      for (int i = 0; i < num_entries; ++i) {
        for (int d = 1; d < rank; ++d) {
          output_indices_t(i, d - 1) = indices(i, d);
        }
        output_values_t(i) = values(i);
      }

      SparseTensor st_i(output_indices, output_values, output_shape);
      int64 handle;
      OP_REQUIRES_OK(context, map->AddSparseTensor(context, st_i, &handle));
      sparse_handles_t(b) = handle;
    }

    // Fill in any gaps; we must provide an empty SparseTensor for batch
    // entries the grouper didn't find.
    if (static_cast<int64>(visited.size()) < N) {
      Tensor empty_indices(DT_INT64, {0, rank - 1});
      Tensor empty_values(DataTypeToEnum<T>::value, {0});
      SparseTensor empty_st(empty_indices, empty_values, output_shape);

      for (int64 b = 0; b < N; ++b) {
        // We skipped this batch entry.
        if (visited.find(b) == visited.end()) {
          int64 handle;
          OP_REQUIRES_OK(context,
                         map->AddSparseTensor(context, empty_st, &handle));
          sparse_handles_t(b) = handle;
        }
      }
    }

    context->set_output(0, sparse_handles);
  }
};

#define REGISTER_KERNELS(type)                              \
  REGISTER_KERNEL_BUILDER(Name("AddManySparseToTensorsMap") \
                              .Device(DEVICE_CPU)           \
                              .TypeConstraint<type>("T"),   \
                          AddManySparseToTensorsMapOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

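// TakeManySparseFromTensorsMapOp consumes N handles produced by the ops
// above, removes the corresponding SparseTensors from the shared map, and
// concatenates them along a new leading minibatch dimension, outputting a
// single rank-(R + 1) SparseTensor as indices/values/shape tensors.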
template <typename T>
class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp {
 public:
  explicit TakeManySparseFromTensorsMapOp(OpKernelConstruction* context)
      : SparseTensorAccessingOp(context) {}

  void Compute(OpKernelContext* context) override {
    SparseTensorsMap* map = nullptr;
    OP_REQUIRES_OK(context, GetMap(context, false /* is_writing */, &map));

    const Tensor& sparse_handles = context->input(0);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_handles.shape()),
                errors::InvalidArgument(
                    "sparse_handles should be a vector but received shape ",
                    sparse_handles.shape().DebugString()));

    int64 N = sparse_handles.shape().dim_size(0);

    OP_REQUIRES(
        context, N > 0,
        errors::InvalidArgument("Must have at least 1 SparseTensor handle, "
                                "but the input vector has 0 elements"));

    std::vector<Tensor> indices_to_concat;
    std::vector<Tensor> values_to_concat;
    std::vector<TensorShape> shapes_to_concat;

    const auto& sparse_handles_t = sparse_handles.vec<int64>();

    std::vector<SparseTensor> sparse_tensors;

    OP_REQUIRES_OK(context, map->RetrieveAndClearSparseTensors(
                                context, sparse_handles_t, &sparse_tensors));

    for (int64 i = 0; i < N; ++i) {
      const SparseTensor& st = sparse_tensors[i];
      const Tensor& output_indices = st.indices();
      const Tensor& output_values = st.values();
      const auto output_shape = st.shape();

      OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()),
                  errors::InvalidArgument(
                      "Expected sparse_handles[", i,
                      "] to represent an index matrix but received shape ",
                      output_indices.shape().DebugString()));
      OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()),
                  errors::InvalidArgument(
                      "Expected sparse_handles[", i,
                      "] to represent a values vector but received shape ",
                      output_values.shape().DebugString()));
      OP_REQUIRES(
          context, DataTypeToEnum<T>::value == output_values.dtype(),
          errors::InvalidArgument(
              "Requested SparseTensor of type ",
              DataTypeString(DataTypeToEnum<T>::value), " but SparseTensor[", i,
              "].values.dtype() == ", DataTypeString(output_values.dtype())));

      int64 num_entries = output_indices.dim_size(0);
      OP_REQUIRES(context, num_entries == output_values.dim_size(0),
                  errors::InvalidArgument(
                      "Expected row counts of SparseTensor[", i,
                      "].indices and SparseTensor[", i,
                      "].values to match but they do not: ", num_entries,
                      " vs. ", output_values.dim_size(0)));
      int rank = output_indices.dim_size(1);
      OP_REQUIRES(
          context, rank == output_shape.size(),
          errors::InvalidArgument("Expected column counts of SparseTensor[", i,
                                  "].indices to match size of SparseTensor[", i,
                                  "].shape "
                                  "but they do not: ",
                                  rank, " vs. ", output_shape.size()));

      // Now we expand each SparseTensor's indices and shape by prefixing
      // a new leading minibatch dimension.
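      // For example, indices [[2], [4]] with shape [5] become indices
      // [[0, 2], [0, 4]] with shape [1, 5].  SparseTensor::Concat later
      // offsets dimension 0 so each tensor lands at its minibatch position.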
      Tensor expanded_indices(
          DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)}));
      Tensor expanded_shape(DT_INT64, TensorShape({1 + rank}));
      const auto& output_indices_t = output_indices.matrix<int64>();
      auto expanded_indices_t = expanded_indices.matrix<int64>();
      auto expanded_shape_t = expanded_shape.vec<int64>();
      expanded_indices_t.chip<1>(0).setZero();
      Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
      Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
      expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t;
      expanded_shape_t(0) = 1;
      // TODO: copy shape from TensorShape to &expanded_shape_t(1)
      // std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
      for (int d = 0; d < rank; ++d) {
        expanded_shape_t(d + 1) = output_shape[d];
      }
      TensorShape expanded_tensor_shape(expanded_shape_t);

      indices_to_concat.push_back(std::move(expanded_indices));
      values_to_concat.push_back(output_values);
      shapes_to_concat.push_back(std::move(expanded_tensor_shape));
    }

    int rank = -1;
    for (int i = 0; i < N; ++i) {
      if (rank < 0) rank = shapes_to_concat[i].dims();
      OP_REQUIRES(context, rank == shapes_to_concat[i].dims(),
                  errors::InvalidArgument(
                      "Inconsistent rank across SparseTensors: rank prior to "
                      "SparseTensor[",
                      i, "] was: ", rank, " but rank of SparseTensor[", i,
                      "] is: ", shapes_to_concat[i].dims()));
    }

    // SparseTensor::Concat requires consistent shape for all but the
    // primary order dimension (dimension 0 in this case).  So we get
    // the maximum value across all the input SparseTensors for each
    // dimension and use that.
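    // For example, pre-concat shapes [1, 3, 4] and [1, 5, 2] yield a common
    // shape of [1, 5, 4]; padding to the elementwise max is safe because
    // absent positions in a SparseTensor are implicitly empty.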
    TensorShape preconcat_shape(shapes_to_concat[0]);
    for (int i = 0; i < N; ++i) {
      for (int d = 0; d < rank; ++d) {
        preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d),
                                            shapes_to_concat[i].dim_size(d)));
      }
    }

    // Dimension 0 is the primary dimension.
    gtl::InlinedVector<int64, 8> std_order(rank);
    std::iota(std_order.begin(), std_order.end(), 0);

    std::vector<SparseTensor> tensors_to_concat;
    tensors_to_concat.reserve(N);
    for (int i = 0; i < N; ++i) {
      tensors_to_concat.emplace_back(std::move(indices_to_concat[i]),
                                     std::move(values_to_concat[i]),
                                     preconcat_shape, std_order);
    }

    SparseTensor output(SparseTensor::Concat<T>(tensors_to_concat));

    Tensor final_output_shape(DT_INT64, TensorShape({output.dims()}));

    std::copy_n(output.shape().data(), output.dims(),
                final_output_shape.vec<int64>().data());

    context->set_output(0, output.indices());
    context->set_output(1, output.values());
    context->set_output(2, final_output_shape);
  }
};

#define REGISTER_KERNELS(type)                                 \
  REGISTER_KERNEL_BUILDER(Name("TakeManySparseFromTensorsMap") \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<type>("dtype"),  \
                          TakeManySparseFromTensorsMapOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

}  // namespace tensorflow