1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <numeric> 16 17 #include "tensorflow/core/framework/partial_tensor_shape.h" 18 #include "tensorflow/core/framework/register_types.h" 19 #include "tensorflow/core/framework/tensor.h" 20 #include "tensorflow/core/kernels/data/dataset.h" 21 #include "tensorflow/core/util/sparse/sparse_tensor.h" 22 23 namespace tensorflow { 24 25 namespace { 26 27 // See documentation in ../ops/dataset_ops.cc for a high-level 28 // description of the following op. 29 30 template <typename T> 31 class Dataset : public GraphDatasetBase { 32 public: 33 explicit Dataset(OpKernelContext* ctx, 34 const sparse::SparseTensor& sparse_tensor) 35 : GraphDatasetBase(ctx), 36 sparse_tensor_(sparse_tensor), 37 dtypes_({DT_INT64, sparse_tensor.dtype(), DT_INT64}), 38 shapes_({{-1, sparse_tensor.dims() - 1}, 39 {-1}, 40 {sparse_tensor.dims() - 1}}) {} 41 42 std::unique_ptr<IteratorBase> MakeIterator( 43 const string& prefix) const override { 44 return std::unique_ptr<IteratorBase>( 45 new Iterator({this, strings::StrCat(prefix, "::SparseTensorSlice")})); 46 } 47 48 const DataTypeVector& output_dtypes() const override { return dtypes_; } 49 const std::vector<PartialTensorShape>& output_shapes() const override { 50 return shapes_; 51 } 52 53 string DebugString() override { 54 return "SparseTensorSliceDatasetOp::Dataset"; 55 } 56 57 protected: 58 Status AsGraphDefInternal(DatasetGraphDefBuilder* b, 59 Node** output) const override { 60 Node* indices_node; 61 TF_RETURN_IF_ERROR(b->AddTensor(sparse_tensor_.indices(), &indices_node)); 62 Node* value_node; 63 TF_RETURN_IF_ERROR(b->AddTensor(sparse_tensor_.values(), &value_node)); 64 Node* dense_shape_node; 65 std::vector<int64> dense_shape; 66 dense_shape.reserve(sparse_tensor_.shape().size()); 67 for (int i = 0; i < sparse_tensor_.shape().size(); i++) 68 dense_shape.emplace_back(sparse_tensor_.shape()[i]); 69 TF_RETURN_IF_ERROR(b->AddVector(dense_shape, &dense_shape_node)); 70 AttrValue val_dtype; 71 b->BuildAttrValue(sparse_tensor_.dtype(), &val_dtype); 72 TF_RETURN_IF_ERROR( 73 b->AddDataset(this, {indices_node, value_node, dense_shape_node}, 74 {{"Tvalues", val_dtype}}, output)); 75 return Status::OK(); 76 } 77 78 private: 79 class Iterator : public DatasetIterator<Dataset<T>> { 80 public: 81 explicit Iterator(const typename Iterator::Params& params) 82 : DatasetIterator<Dataset<T>>(params), 83 num_elements_(params.dataset->sparse_tensor_.shape()[0]), 84 dense_shape_(DT_INT64, {params.dataset->sparse_tensor_.dims() - 1}), 85 group_iterable_(params.dataset->sparse_tensor_.group({0})), 86 iter_(group_iterable_.begin()) { 87 for (size_t i = 0; i < dense_shape_.NumElements(); ++i) { 88 dense_shape_.vec<int64>()(i) = 89 params.dataset->sparse_tensor_.shape()[i + 1]; 90 } 91 } 92 93 Status GetNextInternal(IteratorContext* ctx, 94 std::vector<Tensor>* out_tensors, 95 bool* end_of_sequence) override { 96 mutex_lock l(mu_); 97 if (i_ == num_elements_) { 98 *end_of_sequence = true; 99 return Status::OK(); 100 } 101 102 out_tensors->clear(); 103 out_tensors->reserve(3); 104 const int rank = Iterator::dataset()->sparse_tensor_.dims(); 105 106 if (i_ > next_non_empty_i_ && iter_ != group_iterable_.end()) { 107 // We still have elements to consume from `group_iterable_` 108 // and we have emitted all elements up to and including the 109 // current position. 110 sparse::Group group = *iter_; 111 const auto indices = group.indices(); 112 const auto values = group.values<T>(); 113 const int64 num_entries = values.size(); 114 next_non_empty_i_ = indices(0, 0); 115 116 next_indices_ = Tensor(DT_INT64, {num_entries, rank - 1}); 117 next_values_ = Tensor(DataTypeToEnum<T>::value, {num_entries}); 118 119 auto next_indices_t = next_indices_.matrix<int64>(); 120 auto next_values_t = next_values_.vec<T>(); 121 122 for (int64 i = 0; i < num_entries; ++i) { 123 for (int d = 1; d < rank; ++d) { 124 next_indices_t(i, d - 1) = indices(i, d); 125 } 126 next_values_t(i) = values(i); 127 } 128 129 ++iter_; 130 } 131 if (i_ == next_non_empty_i_) { 132 // The current position is non-empty in the input 133 // `SparseTensor`, and we have already read the value from the 134 // `GroupIterable`. 135 out_tensors->push_back(std::move(next_indices_)); 136 out_tensors->push_back(std::move(next_values_)); 137 out_tensors->push_back(dense_shape_); 138 next_non_empty_i_ = kNextNonEmptyUnknown; 139 } else { 140 DCHECK(i_ < next_non_empty_i_ || iter_ == group_iterable_.end()); 141 // The current position is empty in the input `SparseTensor`, 142 // so emit empty indices and values. 143 out_tensors->push_back(Tensor(DT_INT64, TensorShape({0, rank - 1}))); 144 out_tensors->push_back(Tensor(DataTypeToEnum<T>::value, {0})); 145 out_tensors->push_back(dense_shape_); 146 } 147 148 ++i_; 149 *end_of_sequence = false; 150 return Status::OK(); 151 } 152 153 protected: 154 Status SaveInternal(IteratorStateWriter* writer) override { 155 mutex_lock l(mu_); 156 TF_RETURN_IF_ERROR(writer->WriteScalar(Iterator::full_name("i"), i_)); 157 TF_RETURN_IF_ERROR( 158 writer->WriteScalar(Iterator::full_name("iter_loc"), iter_.loc())); 159 TF_RETURN_IF_ERROR(writer->WriteScalar( 160 Iterator::full_name("next_non_empty_i_"), next_non_empty_i_)); 161 if (i_ <= next_non_empty_i_) { 162 TF_RETURN_IF_ERROR(writer->WriteTensor( 163 Iterator::full_name("next_indices_"), next_indices_)); 164 TF_RETURN_IF_ERROR(writer->WriteTensor( 165 Iterator::full_name("next_values_"), next_values_)); 166 } 167 return Status::OK(); 168 } 169 170 Status RestoreInternal(IteratorContext* ctx, 171 IteratorStateReader* reader) override { 172 mutex_lock l(mu_); 173 TF_RETURN_IF_ERROR(reader->ReadScalar(Iterator::full_name("i"), &i_)); 174 int64 iter_loc; 175 TF_RETURN_IF_ERROR( 176 reader->ReadScalar(Iterator::full_name("iter_loc"), &iter_loc)); 177 iter_ = group_iterable_.at(iter_loc); 178 TF_RETURN_IF_ERROR(reader->ReadScalar( 179 Iterator::full_name("next_non_empty_i_"), &next_non_empty_i_)); 180 if (i_ <= next_non_empty_i_) { 181 TF_RETURN_IF_ERROR(reader->ReadTensor( 182 Iterator::full_name("next_indices_"), &next_indices_)); 183 TF_RETURN_IF_ERROR(reader->ReadTensor( 184 Iterator::full_name("next_values_"), &next_values_)); 185 } 186 return Status::OK(); 187 } 188 189 private: 190 const int64 num_elements_; 191 192 Tensor dense_shape_; 193 194 mutex mu_; 195 sparse::GroupIterable group_iterable_ GUARDED_BY(mu_); 196 sparse::GroupIterable::IteratorStep iter_ GUARDED_BY(mu_); 197 int64 i_ GUARDED_BY(mu_) = 0; 198 const int64 kNextNonEmptyUnknown = -1; 199 int64 next_non_empty_i_ GUARDED_BY(mu_) = kNextNonEmptyUnknown; 200 Tensor next_indices_ GUARDED_BY(mu_); 201 Tensor next_values_ GUARDED_BY(mu_); 202 }; 203 204 const sparse::SparseTensor sparse_tensor_; 205 const DataTypeVector dtypes_; 206 const std::vector<PartialTensorShape> shapes_; 207 }; 208 209 template <typename T> 210 class SparseTensorSliceDatasetOp : public DatasetOpKernel { 211 public: 212 explicit SparseTensorSliceDatasetOp(OpKernelConstruction* ctx) 213 : DatasetOpKernel(ctx) {} 214 215 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 216 // Create a new SparseTensorSliceDatasetOp::Dataset, insert it in 217 // the step container, and return it as the output. 218 const Tensor* indices; 219 OP_REQUIRES_OK(ctx, ctx->input("indices", &indices)); 220 const Tensor* values; 221 OP_REQUIRES_OK(ctx, ctx->input("values", &values)); 222 const Tensor* dense_shape; 223 OP_REQUIRES_OK(ctx, ctx->input("dense_shape", &dense_shape)); 224 225 OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(indices->shape()), 226 errors::InvalidArgument( 227 "Input indices should be a matrix but received shape ", 228 indices->shape().DebugString())); 229 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(values->shape()), 230 errors::InvalidArgument( 231 "Input values should be a vector but received shape ", 232 indices->shape().DebugString())); 233 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(dense_shape->shape()), 234 errors::InvalidArgument( 235 "Input shape should be a vector but received shape ", 236 dense_shape->shape().DebugString())); 237 238 // We currently ensure that `sparse_tensor` is ordered in the 239 // batch dimension. 240 // TODO(mrry): Investigate ways to avoid this unconditional check 241 // if we can be sure that the sparse tensor was produced in an 242 // appropriate order (e.g. by `tf.parse_example()` or a Dataset 243 // that batches elements into rows of a SparseTensor). 244 int64 previous_batch_index = -1; 245 for (int64 i = 0; i < indices->dim_size(0); ++i) { 246 int64 next_batch_index = indices->matrix<int64>()(i, 0); 247 OP_REQUIRES( 248 ctx, next_batch_index >= previous_batch_index, 249 errors::Unimplemented("The SparseTensor must be ordered in the batch " 250 "dimension; handling arbitrarily ordered input " 251 "is not currently supported.")); 252 previous_batch_index = next_batch_index; 253 } 254 gtl::InlinedVector<int64, 8> std_order(dense_shape->NumElements(), 0); 255 sparse::SparseTensor sparse_tensor( 256 *indices, *values, TensorShape(dense_shape->vec<int64>()), std_order); 257 258 *output = new Dataset<T>(ctx, sparse_tensor); 259 } 260 261 private: 262 }; 263 264 #define REGISTER_DATASET_KERNEL(type) \ 265 REGISTER_KERNEL_BUILDER(Name("SparseTensorSliceDataset") \ 266 .Device(DEVICE_CPU) \ 267 .TypeConstraint<type>("Tvalues"), \ 268 SparseTensorSliceDatasetOp<type>); 269 270 TF_CALL_DATASET_TYPES(REGISTER_DATASET_KERNEL); 271 #undef REGISTER_DATASET_KERNEL 272 273 } // namespace 274 275 } // namespace tensorflow 276