// (code-viewer navigation header removed — not part of the source file)
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <numeric>
     16 
     17 #include "tensorflow/core/framework/partial_tensor_shape.h"
     18 #include "tensorflow/core/framework/register_types.h"
     19 #include "tensorflow/core/framework/tensor.h"
     20 #include "tensorflow/core/kernels/data/dataset.h"
     21 #include "tensorflow/core/util/sparse/sparse_tensor.h"
     22 
     23 namespace tensorflow {
     24 
     25 namespace {
     26 
     27 // See documentation in ../ops/dataset_ops.cc for a high-level
     28 // description of the following op.
     29 
// Dataset that slices a `SparseTensor` along its first (batch) dimension,
// emitting one element per position in that dimension. Each element is a
// triple of tensors describing one sparse slice:
//   indices:     [num_entries, dims - 1] int64 matrix of coordinates.
//   values:      [num_entries] vector of type T.
//   dense_shape: [dims - 1] int64 vector (the input shape minus dim 0).
template <typename T>
class Dataset : public GraphDatasetBase {
 public:
  explicit Dataset(OpKernelContext* ctx,
                   const sparse::SparseTensor& sparse_tensor)
      : GraphDatasetBase(ctx),
        sparse_tensor_(sparse_tensor),
        // Output types for (indices, values, dense_shape).
        dtypes_({DT_INT64, sparse_tensor.dtype(), DT_INT64}),
        // -1 marks the per-slice entry count, unknown until iteration.
        shapes_({{-1, sparse_tensor.dims() - 1},
                 {-1},
                 {sparse_tensor.dims() - 1}}) {}

  std::unique_ptr<IteratorBase> MakeIterator(
      const string& prefix) const override {
    return std::unique_ptr<IteratorBase>(
        new Iterator({this, strings::StrCat(prefix, "::SparseTensorSlice")}));
  }

  const DataTypeVector& output_dtypes() const override { return dtypes_; }
  const std::vector<PartialTensorShape>& output_shapes() const override {
    return shapes_;
  }

  string DebugString() override {
    return "SparseTensorSliceDatasetOp::Dataset";
  }

 protected:
  // Serializes this dataset into the graph so it can be rebuilt from a
  // checkpoint: re-adds the indices/values/dense_shape tensors as graph
  // constants and wires them into a new dataset node (with the value
  // dtype recorded in the "Tvalues" attr).
  Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* indices_node;
    TF_RETURN_IF_ERROR(b->AddTensor(sparse_tensor_.indices(), &indices_node));
    Node* value_node;
    TF_RETURN_IF_ERROR(b->AddTensor(sparse_tensor_.values(), &value_node));
    Node* dense_shape_node;
    std::vector<int64> dense_shape;
    dense_shape.reserve(sparse_tensor_.shape().size());
    for (int i = 0; i < sparse_tensor_.shape().size(); i++)
      dense_shape.emplace_back(sparse_tensor_.shape()[i]);
    TF_RETURN_IF_ERROR(b->AddVector(dense_shape, &dense_shape_node));
    AttrValue val_dtype;
    b->BuildAttrValue(sparse_tensor_.dtype(), &val_dtype);
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {indices_node, value_node, dense_shape_node},
                      {{"Tvalues", val_dtype}}, output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset<T>> {
   public:
    explicit Iterator(const typename Iterator::Params& params)
        : DatasetIterator<Dataset<T>>(params),
          // One output element per position along the batch dimension.
          num_elements_(params.dataset->sparse_tensor_.shape()[0]),
          dense_shape_(DT_INT64, {params.dataset->sparse_tensor_.dims() - 1}),
          // Group the sparse entries by their batch (dim 0) coordinate.
          group_iterable_(params.dataset->sparse_tensor_.group({0})),
          iter_(group_iterable_.begin()) {
      // Precompute the per-slice dense shape: the input shape with the
      // batch dimension dropped. Shared by every emitted element.
      for (size_t i = 0; i < dense_shape_.NumElements(); ++i) {
        dense_shape_.vec<int64>()(i) =
            params.dataset->sparse_tensor_.shape()[i + 1];
      }
    }

    // Emits the slice at position `i_`. Positions with no sparse entries
    // yield empty (0-row) indices and values tensors. Relies on the
    // SparseTensor being ordered in the batch dimension (enforced by the
    // op kernel before this dataset is constructed). Reads one group
    // ahead into `next_indices_`/`next_values_` and caches the batch
    // index of that group in `next_non_empty_i_`.
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      if (i_ == num_elements_) {
        *end_of_sequence = true;
        return Status::OK();
      }

      out_tensors->clear();
      out_tensors->reserve(3);
      const int rank = Iterator::dataset()->sparse_tensor_.dims();

      if (i_ > next_non_empty_i_ && iter_ != group_iterable_.end()) {
        // We still have elements to consume from `group_iterable_`
        // and we have emitted all elements up to and including the
        // current position. Read the next group ahead: record which
        // batch index it belongs to and materialize its indices
        // (with the batch column dropped) and values.
        sparse::Group group = *iter_;
        const auto indices = group.indices();
        const auto values = group.values<T>();
        const int64 num_entries = values.size();
        next_non_empty_i_ = indices(0, 0);

        next_indices_ = Tensor(DT_INT64, {num_entries, rank - 1});
        next_values_ = Tensor(DataTypeToEnum<T>::value, {num_entries});

        auto next_indices_t = next_indices_.matrix<int64>();
        auto next_values_t = next_values_.vec<T>();

        for (int64 i = 0; i < num_entries; ++i) {
          // Copy all but the leading (batch) coordinate of each index.
          for (int d = 1; d < rank; ++d) {
            next_indices_t(i, d - 1) = indices(i, d);
          }
          next_values_t(i) = values(i);
        }

        ++iter_;
      }
      if (i_ == next_non_empty_i_) {
        // The current position is non-empty in the input
        // `SparseTensor`, and we have already read the value from the
        // `GroupIterable`. Move the cached tensors out and mark the
        // read-ahead slot as consumed.
        out_tensors->push_back(std::move(next_indices_));
        out_tensors->push_back(std::move(next_values_));
        out_tensors->push_back(dense_shape_);
        next_non_empty_i_ = kNextNonEmptyUnknown;
      } else {
        DCHECK(i_ < next_non_empty_i_ || iter_ == group_iterable_.end());
        // The current position is empty in the input `SparseTensor`,
        // so emit empty indices and values.
        out_tensors->push_back(Tensor(DT_INT64, TensorShape({0, rank - 1})));
        out_tensors->push_back(Tensor(DataTypeToEnum<T>::value, {0}));
        out_tensors->push_back(dense_shape_);
      }

      ++i_;
      *end_of_sequence = false;
      return Status::OK();
    }

   protected:
    // Checkpoints the iterator: current position, the group iterator's
    // location, and — only while a read-ahead slice is pending
    // (i_ <= next_non_empty_i_) — the cached indices/values tensors.
    Status SaveInternal(IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(writer->WriteScalar(Iterator::full_name("i"), i_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(Iterator::full_name("iter_loc"), iter_.loc()));
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          Iterator::full_name("next_non_empty_i_"), next_non_empty_i_));
      if (i_ <= next_non_empty_i_) {
        TF_RETURN_IF_ERROR(writer->WriteTensor(
            Iterator::full_name("next_indices_"), next_indices_));
        TF_RETURN_IF_ERROR(writer->WriteTensor(
            Iterator::full_name("next_values_"), next_values_));
      }
      return Status::OK();
    }

    // Restores state written by SaveInternal; the read-ahead tensors are
    // only present when they were pending at save time.
    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(reader->ReadScalar(Iterator::full_name("i"), &i_));
      int64 iter_loc;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(Iterator::full_name("iter_loc"), &iter_loc));
      iter_ = group_iterable_.at(iter_loc);
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          Iterator::full_name("next_non_empty_i_"), &next_non_empty_i_));
      if (i_ <= next_non_empty_i_) {
        TF_RETURN_IF_ERROR(reader->ReadTensor(
            Iterator::full_name("next_indices_"), &next_indices_));
        TF_RETURN_IF_ERROR(reader->ReadTensor(
            Iterator::full_name("next_values_"), &next_values_));
      }
      return Status::OK();
    }

   private:
    // Size of the batch (first) dimension; total number of elements.
    const int64 num_elements_;

    // [dims - 1] int64 vector emitted with every element.
    Tensor dense_shape_;

    mutex mu_;
    sparse::GroupIterable group_iterable_ GUARDED_BY(mu_);
    sparse::GroupIterable::IteratorStep iter_ GUARDED_BY(mu_);
    // Next position to emit, in [0, num_elements_].
    int64 i_ GUARDED_BY(mu_) = 0;
    // Sentinel: no group has been read ahead yet (or the last one was
    // consumed).
    const int64 kNextNonEmptyUnknown = -1;
    // Batch index of the read-ahead group, or kNextNonEmptyUnknown.
    int64 next_non_empty_i_ GUARDED_BY(mu_) = kNextNonEmptyUnknown;
    // Read-ahead slice, valid only while i_ <= next_non_empty_i_.
    Tensor next_indices_ GUARDED_BY(mu_);
    Tensor next_values_ GUARDED_BY(mu_);
  };

  const sparse::SparseTensor sparse_tensor_;
  const DataTypeVector dtypes_;
  const std::vector<PartialTensorShape> shapes_;
};
    208 
    209 template <typename T>
    210 class SparseTensorSliceDatasetOp : public DatasetOpKernel {
    211  public:
    212   explicit SparseTensorSliceDatasetOp(OpKernelConstruction* ctx)
    213       : DatasetOpKernel(ctx) {}
    214 
    215   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    216     // Create a new SparseTensorSliceDatasetOp::Dataset, insert it in
    217     // the step container, and return it as the output.
    218     const Tensor* indices;
    219     OP_REQUIRES_OK(ctx, ctx->input("indices", &indices));
    220     const Tensor* values;
    221     OP_REQUIRES_OK(ctx, ctx->input("values", &values));
    222     const Tensor* dense_shape;
    223     OP_REQUIRES_OK(ctx, ctx->input("dense_shape", &dense_shape));
    224 
    225     OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(indices->shape()),
    226                 errors::InvalidArgument(
    227                     "Input indices should be a matrix but received shape ",
    228                     indices->shape().DebugString()));
    229     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(values->shape()),
    230                 errors::InvalidArgument(
    231                     "Input values should be a vector but received shape ",
    232                     indices->shape().DebugString()));
    233     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(dense_shape->shape()),
    234                 errors::InvalidArgument(
    235                     "Input shape should be a vector but received shape ",
    236                     dense_shape->shape().DebugString()));
    237 
    238     // We currently ensure that `sparse_tensor` is ordered in the
    239     // batch dimension.
    240     // TODO(mrry): Investigate ways to avoid this unconditional check
    241     // if we can be sure that the sparse tensor was produced in an
    242     // appropriate order (e.g. by `tf.parse_example()` or a Dataset
    243     // that batches elements into rows of a SparseTensor).
    244     int64 previous_batch_index = -1;
    245     for (int64 i = 0; i < indices->dim_size(0); ++i) {
    246       int64 next_batch_index = indices->matrix<int64>()(i, 0);
    247       OP_REQUIRES(
    248           ctx, next_batch_index >= previous_batch_index,
    249           errors::Unimplemented("The SparseTensor must be ordered in the batch "
    250                                 "dimension; handling arbitrarily ordered input "
    251                                 "is not currently supported."));
    252       previous_batch_index = next_batch_index;
    253     }
    254     gtl::InlinedVector<int64, 8> std_order(dense_shape->NumElements(), 0);
    255     sparse::SparseTensor sparse_tensor(
    256         *indices, *values, TensorShape(dense_shape->vec<int64>()), std_order);
    257 
    258     *output = new Dataset<T>(ctx, sparse_tensor);
    259   }
    260 
    261  private:
    262 };
    263 
// Registers a CPU kernel of SparseTensorSliceDatasetOp for every dataset
// value dtype (the "Tvalues" type constraint selects the instantiation).
#define REGISTER_DATASET_KERNEL(type)                           \
  REGISTER_KERNEL_BUILDER(Name("SparseTensorSliceDataset")      \
                              .Device(DEVICE_CPU)               \
                              .TypeConstraint<type>("Tvalues"), \
                          SparseTensorSliceDatasetOp<type>);

TF_CALL_DATASET_TYPES(REGISTER_DATASET_KERNEL);
#undef REGISTER_DATASET_KERNEL
    272 
    273 }  // namespace
    274 
    275 }  // namespace tensorflow
    276